diff --git a/README-ja.md b/README-ja.md index 865e0d35a..29c33a659 100644 --- a/README-ja.md +++ b/README-ja.md @@ -1,3 +1,7 @@ +SDXLがサポートされました。sdxlブランチはmainブランチにマージされました。リポジトリを更新したときにはUpgradeの手順を実行してください。また accelerate のバージョンが上がっていますので、accelerate config を再度実行してください。 + +SDXL学習については[こちら](./README.md#sdxl-training)をご覧ください(英語です)。 + ## リポジトリについて Stable Diffusionの学習、画像生成、その他のスクリプトを入れたリポジトリです。 @@ -9,13 +13,12 @@ GUIやPowerShellスクリプトなど、より使いやすくする機能が[bma * DreamBooth、U-NetおよびText Encoderの学習をサポート * fine-tuning、同上 +* LoRAの学習をサポート * 画像生成 * モデル変換(Stable Diffision ckpt/safetensorsとDiffusersの相互変換) ## 使用法について -当リポジトリ内およびnote.comに記事がありますのでそちらをご覧ください(将来的にはすべてこちらへ移すかもしれません)。 - * [学習について、共通編](./docs/train_README-ja.md) : データ整備やオプションなど * [データセット設定](./docs/config_README-ja.md) * [DreamBoothの学習について](./docs/train_db_README-ja.md) @@ -41,11 +44,13 @@ PowerShellを使う場合、venvを使えるようにするためには以下の ## Windows環境でのインストール -以下の例ではPyTorchは1.12.1/CUDA 11.6版をインストールします。CUDA 11.3版やPyTorch 1.13を使う場合は適宜書き換えください。 +スクリプトはPyTorch 2.0.1でテストしています。PyTorch 1.12.1でも動作すると思われます。 + +以下の例ではPyTorchは2.0.1/CUDA 11.8版をインストールします。CUDA 11.6版やPyTorch 1.12.1を使う場合は適宜書き換えください。 (なお、python -m venv~の行で「python」とだけ表示された場合、py -m venv~のようにpythonをpyに変更してください。) -通常の(管理者ではない)PowerShellを開き以下を順に実行します。 +PowerShellを使う場合、通常の(管理者ではない)PowerShellを開き以下を順に実行します。 ```powershell git clone https://github.com/kohya-ss/sd-scripts.git @@ -54,43 +59,14 @@ cd sd-scripts python -m venv venv .\venv\Scripts\activate -pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 +pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118 pip install --upgrade -r requirements.txt -pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl - -cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\ -cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py -cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py +pip install xformers==0.0.20 accelerate config ``` - - -コマンドプロンプトでは以下になります。 - - -```bat -git clone https://github.com/kohya-ss/sd-scripts.git -cd sd-scripts - -python -m venv venv -.\venv\Scripts\activate - -pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 -pip install --upgrade -r requirements.txt -pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl - -copy /y .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\ -copy /y .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py -copy /y .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py - -accelerate config -``` +コマンドプロンプトでも同一です。 (注:``python -m venv venv`` のほうが ``python -m venv --system-site-packages venv`` より安全そうなため書き換えました。globalなpythonにパッケージがインストールしてあると、後者だといろいろと問題が起きます。) @@ -111,29 +87,40 @@ accelerate configの質問には以下のように答えてください。(bf1 ※場合によって ``ValueError: fp16 mixed precision requires a GPU`` というエラーが出ることがあるようです。この場合、6番目の質問( ``What GPU(s) (by id) should be used for training on this machine as a comma-separated list? 
[all]:``)に「0」と答えてください。(id `0`のGPUが使われます。) -### PyTorchとxformersのバージョンについて +### オプション:`bitsandbytes`(8bit optimizer)を使う -他のバージョンでは学習がうまくいかない場合があるようです。特に他の理由がなければ指定のバージョンをお使いください。 +`bitsandbytes`はオプションになりました。Linuxでは通常通りpipでインストールできます(0.41.1または以降のバージョンを推奨)。 -### オプション:Lion8bitを使う +Windowsでは0.35.0または0.41.1を推奨します。 -Lion8bitを使う場合には`bitsandbytes`を0.38.0以降にアップグレードする必要があります。`bitsandbytes`をアンインストールし、Windows環境では例えば[こちら](https://github.com/jllllll/bitsandbytes-windows-webui)などからWindows版のwhlファイルをインストールしてください。たとえば以下のような手順になります。 +- `bitsandbytes` 0.35.0: 安定しているとみられるバージョンです。AdamW8bitは使用できますが、他のいくつかの8bit optimizer、学習時の`full_bf16`オプションは使用できません。 +- `bitsandbytes` 0.41.1: Lion8bit、PagedAdamW8bit、PagedLion8bitをサポートします。`full_bf16`が使用できます。 -```powershell -pip install https://github.com/jllllll/bitsandbytes-windows-webui/raw/main/bitsandbytes-0.38.1-py3-none-any.whl -``` +注:`bitsandbytes` 0.35.0から0.41.0までのバージョンには問題があるようです。 https://github.com/TimDettmers/bitsandbytes/issues/659 -アップグレード時には`pip install .`でこのリポジトリを更新し、必要に応じて他のパッケージもアップグレードしてください。 +以下の手順に従い、`bitsandbytes`をインストールしてください。 -### オプション:PagedAdamW8bitとPagedLion8bitを使う +### 0.35.0を使う場合 -PagedAdamW8bitとPagedLion8bitを使う場合には`bitsandbytes`を0.39.0以降にアップグレードする必要があります。`bitsandbytes`をアンインストールし、Windows環境では例えば[こちら](https://github.com/jllllll/bitsandbytes-windows-webui)などからWindows版のwhlファイルをインストールしてください。たとえば以下のような手順になります。 +PowerShellの例です。コマンドプロンプトではcpの代わりにcopyを使ってください。 ```powershell -pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl +cd sd-scripts +.\venv\Scripts\activate +pip install bitsandbytes==0.35.0 + +cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\ +cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py +cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py ``` -アップグレード時には`pip install .`でこのリポジトリを更新し、必要に応じて他のパッケージもアップグレードしてください。 +### 0.41.1を使う場合 + +jllllll氏の配布されている[こちら](https://github.com/jllllll/bitsandbytes-windows-webui) または他の場所から、Windows用のwhlファイルをインストールしてください。 + +```powershell +python -m pip install bitsandbytes==0.41.1 --prefer-binary --extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui +``` ## アップグレード diff --git a/README.md b/README.md index 505994c0e..dc8e25ad6 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,11 @@ +__SDXL is now supported. The sdxl branch has been merged into the main branch. If you update the repository, please follow the upgrade instructions. Also, the version of accelerate has been updated, so please run accelerate config again.__ The documentation for SDXL training is [here](./README.md#sdxl-training). + This repository contains training, generation and utility scripts for Stable Diffusion. -[__Change History__](#change-history) is moved to the bottom of the page. +[__Change History__](#change-history) is moved to the bottom of the page. 更新履歴は[ページ末尾](#change-history)に移しました。 -[日本語版README](./README-ja.md) +[日本語版READMEはこちら](./README-ja.md) For easier use (GUI and PowerShell scripts etc...), please visit [the repository maintained by bmaltais](https://github.com/bmaltais/kohya_ss). Thanks to @bmaltais! @@ -16,15 +18,13 @@ This repository contains the scripts for: * Image generation * Model conversion (supports 1.x and 2.x, Stable Diffision ckpt/safetensors and Diffusers) -__Stable Diffusion web UI now seems to support LoRA trained by ``sd-scripts``.__ Thank you for great work!!! 
- ## About requirements.txt These files do not contain requirements for PyTorch. Because the versions of them depend on your environment. Please install PyTorch at first (see installation guide below.) -The scripts are tested with PyTorch 1.12.1 and 1.13.0, Diffusers 0.10.2. +The scripts are tested with Pytorch 2.0.1. 1.12.1 is not tested but should work. -## Links to how-to-use documents +## Links to usage documentation Most of the documents are written in Japanese. @@ -64,19 +64,20 @@ cd sd-scripts python -m venv venv .\venv\Scripts\activate -pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 +pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118 pip install --upgrade -r requirements.txt -pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl - -cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\ -cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py -cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py +pip install xformers==0.0.20 accelerate config ``` -update: ``python -m venv venv`` is seemed to be safer than ``python -m venv --system-site-packages venv`` (some user have packages in global python). +__Note:__ Now bitsandbytes is optional. Please install any version of bitsandbytes as needed. Installation instructions are in the following section. + Answers to accelerate config: ```txt @@ -94,31 +95,43 @@ note: Some user reports ``ValueError: fp16 mixed precision requires a GPU`` is o (Single GPU with id `0` will be used.) -### about PyTorch and xformers +### Optional: Use `bitsandbytes` (8bit optimizer) + +For 8bit optimizer, you need to install `bitsandbytes`. For Linux, please install `bitsandbytes` as usual (0.41.1 or later is recommended.) + +For Windows, there are several versions of `bitsandbytes`: -Other versions of PyTorch and xformers seem to have problems with training. -If there is no other reason, please install the specified version. +- `bitsandbytes` 0.35.0: Stable version. AdamW8bit is available. `full_bf16` is not available. +- `bitsandbytes` 0.41.1: Lion8bit, PagedAdamW8bit and PagedLion8bit are available. `full_bf16` is available. -### Optional: Use Lion8bit +Note: `bitsandbytes`above 0.35.0 till 0.41.0 seems to have an issue: https://github.com/TimDettmers/bitsandbytes/issues/659 -For Lion8bit, you need to upgrade `bitsandbytes` to 0.38.0 or later. Uninstall `bitsandbytes`, and for Windows, install the Windows version whl file from [here](https://github.com/jllllll/bitsandbytes-windows-webui) or other sources, like: +Follow the instructions below to install `bitsandbytes` for Windows. + +### bitsandbytes 0.35.0 for Windows + +Open a regular Powershell terminal and type the following inside: ```powershell -pip install https://github.com/jllllll/bitsandbytes-windows-webui/raw/main/bitsandbytes-0.38.1-py3-none-any.whl +cd sd-scripts +.\venv\Scripts\activate +pip install bitsandbytes==0.35.0 + +cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\ +cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py +cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py ``` -For upgrading, upgrade this repo with `pip install .`, and upgrade necessary packages manually. 
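+If you want to double-check the result, the following optional sanity check should work once the files above have been copied. This is only an illustration, not part of the official setup; it assumes the venv is still active and that a CUDA-capable GPU is present:
+
+```python
+import torch
+import bitsandbytes as bnb  # an import error or CUDA-setup warning here usually means the copy step above was skipped
+
+# build an AdamW8bit optimizer for a dummy parameter to confirm the 8bit optimizer is usable
+param = torch.nn.Parameter(torch.zeros(8, 8, device="cuda"))
+optimizer = bnb.optim.AdamW8bit([param], lr=1e-4)
+print("bitsandbytes 8bit optimizer created:", type(optimizer).__name__)
+```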
+This will install `bitsandbytes` 0.35.0 and copy the necessary files to the `bitsandbytes` directory. -### Optional: Use PagedAdamW8bit and PagedLion8bit +### bitsandbytes 0.41.1 for Windows -For PagedAdamW8bit and PagedLion8bit, you need to upgrade `bitsandbytes` to 0.39.0 or later. Uninstall `bitsandbytes`, and for Windows, install the Windows version whl file from [here](https://github.com/jllllll/bitsandbytes-windows-webui) or other sources, like: +Install the Windows version whl file from [here](https://github.com/jllllll/bitsandbytes-windows-webui) or other sources, like: ```powershell -pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl +python -m pip install bitsandbytes==0.41.1 --prefer-binary --extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui ``` -For upgrading, upgrade this repo with `pip install .`, and upgrade necessary packages manually. - ## Upgrade When a new release comes out you can upgrade your repo with the following command: @@ -148,214 +161,120 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser [BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause + +## SDXL training + +The documentation in this section will be moved to a separate document later. + +### Training scripts for SDXL + +- `sdxl_train.py` is a script for SDXL fine-tuning. The usage is almost the same as `fine_tune.py`, but it also supports DreamBooth dataset. + - `--full_bf16` option is added. Thanks to KohakuBlueleaf! + - This option enables the full bfloat16 training (includes gradients). This option is useful to reduce the GPU memory usage. + - The full bfloat16 training might be unstable. Please use it at your own risk. + - The different learning rates for each U-Net block are now supported in sdxl_train.py. Specify with `--block_lr` option. Specify 23 values separated by commas like `--block_lr 1e-3,1e-3 ... 1e-3`. + - 23 values correspond to `0: time/label embed, 1-9: input blocks 0-8, 10-12: mid blocks 0-2, 13-21: output blocks 0-8, 22: out`. +- `prepare_buckets_latents.py` now supports SDXL fine-tuning. + +- `sdxl_train_network.py` is a script for LoRA training for SDXL. The usage is almost the same as `train_network.py`. + +- Both scripts has following additional options: + - `--cache_text_encoder_outputs` and `--cache_text_encoder_outputs_to_disk`: Cache the outputs of the text encoders. This option is useful to reduce the GPU memory usage. This option cannot be used with options for shuffling or dropping the captions. + - `--no_half_vae`: Disable the half-precision (mixed-precision) VAE. VAE for SDXL seems to produce NaNs in some cases. This option is useful to avoid the NaNs. + +- `--weighted_captions` option is not supported yet for both scripts. + +- `sdxl_train_textual_inversion.py` is a script for Textual Inversion training for SDXL. The usage is almost the same as `train_textual_inversion.py`. + - `--cache_text_encoder_outputs` is not supported. + - There are two options for captions: + 1. Training with captions. All captions must include the token string. The token string is replaced with multiple tokens. + 2. Use `--use_object_template` or `--use_style_template` option. The captions are generated from the template. The existing captions are ignored. + - See below for the format of the embeddings. + +- `--min_timestep` and `--max_timestep` options are added to each training script. These options can be used to train U-Net with different timesteps. 
The default values are 0 and 1000. + +### Utility scripts for SDXL + +- `tools/cache_latents.py` is added. This script can be used to cache the latents to disk in advance. + - The options are almost the same as `sdxl_train.py'. See the help message for the usage. + - Please launch the script as follows: + `accelerate launch --num_cpu_threads_per_process 1 tools/cache_latents.py ...` + - This script should work with multi-GPU, but it is not tested in my environment. + +- `tools/cache_text_encoder_outputs.py` is added. This script can be used to cache the text encoder outputs to disk in advance. + - The options are almost the same as `cache_latents.py` and `sdxl_train.py`. See the help message for the usage. + +- `sdxl_gen_img.py` is added. This script can be used to generate images with SDXL, including LoRA, Textual Inversion and ControlNet-LLLite. See the help message for the usage. + +### Tips for SDXL training + +- The default resolution of SDXL is 1024x1024. +- The fine-tuning can be done with 24GB GPU memory with the batch size of 1. For 24GB GPU, the following options are recommended __for the fine-tuning with 24GB GPU memory__: + - Train U-Net only. + - Use gradient checkpointing. + - Use `--cache_text_encoder_outputs` option and caching latents. + - Use Adafactor optimizer. RMSprop 8bit or Adagrad 8bit may work. AdamW 8bit doesn't seem to work. +- The LoRA training can be done with 8GB GPU memory (10GB recommended). For reducing the GPU memory usage, the following options are recommended: + - Train U-Net only. + - Use gradient checkpointing. + - Use `--cache_text_encoder_outputs` option and caching latents. + - Use one of 8bit optimizers or Adafactor optimizer. + - Use lower dim (4 to 8 for 8GB GPU). +- `--network_train_unet_only` option is highly recommended for SDXL LoRA. Because SDXL has two text encoders, the result of the training will be unexpected. +- PyTorch 2 seems to use slightly less GPU memory than PyTorch 1. +- `--bucket_reso_steps` can be set to 32 instead of the default value 64. Smaller values than 32 will not work for SDXL training. + +Example of the optimizer settings for Adafactor with the fixed learning rate: +```toml +optimizer_type = "adafactor" +optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ] +lr_scheduler = "constant_with_warmup" +lr_warmup_steps = 100 +learning_rate = 4e-7 # SDXL original learning rate +``` + +### Format of Textual Inversion embeddings for SDXL + +```python +from safetensors.torch import save_file + +state_dict = {"clip_g": embs_for_text_encoder_1280, "clip_l": embs_for_text_encoder_768} +save_file(state_dict, file) +``` + +### ControlNet-LLLite + +ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [documentation](./docs/train_lllite_README.md) for details. + + ## Change History -### 15 Jun. 2023, 2023/06/15 - -- Prodigy optimizer is supported in each training script. It is a member of D-Adaptation and is effective for DyLoRA training. [PR #585](https://github.com/kohya-ss/sd-scripts/pull/585) Please see the PR for details. Thanks to sdbds! - - Install the package with `pip install prodigyopt`. Then specify the option like `--optimizer_type="prodigy"`. -- Arbitrary Dataset is supported in each training script (except XTI). You can use it by defining a Dataset class that returns images and captions. - - Prepare a Python script and define a class that inherits `train_util.MinimalDataset`. 
Then specify the option like `--dataset_class package.module.DatasetClass` in each training script. - - Please refer to `MinimalDataset` for implementation. I will prepare a sample later. -- The following features have been added to the generation script. - - Added an option `--highres_fix_disable_control_net` to disable ControlNet in the 2nd stage of Highres. Fix. Please try it if the image is disturbed by some ControlNet such as Canny. - - Added Variants similar to sd-dynamic-propmpts in the prompt. - - If you specify `{spring|summer|autumn|winter}`, one of them will be randomly selected. - - If you specify `{2$$chocolate|vanilla|strawberry}`, two of them will be randomly selected. - - If you specify `{1-2$$ and $$chocolate|vanilla|strawberry}`, one or two of them will be randomly selected and connected by ` and `. - - You can specify the number of candidates in the range `0-2`. You cannot omit one side like `-2` or `1-`. - - It can also be specified for the prompt option. - - If you specify `e` or `E`, all candidates will be selected and the prompt will be repeated multiple times (`--images_per_prompt` is ignored). It may be useful for creating X/Y plots. - - You can also specify `--am {e$$0.2|0.4|0.6|0.8|1.0},{e$$0.4|0.7|1.0} --d 1234`. In this case, 15 prompts will be generated with 5*3. - - There is no weighting function. - -- 各学習スクリプトでProdigyオプティマイザがサポートされました。D-Adaptationの仲間でDyLoRAの学習に有効とのことです。 [PR #585](https://github.com/kohya-ss/sd-scripts/pull/585) 詳細はPRをご覧ください。sdbds氏に感謝します。 - - `pip install prodigyopt` としてパッケージをインストールしてください。また `--optimizer_type="prodigy"` のようにオプションを指定します。 -- 各学習スクリプトで任意のDatasetをサポートしました(XTIを除く)。画像とキャプションを返すDatasetクラスを定義することで、学習スクリプトから利用できます。 - - Pythonスクリプトを用意し、`train_util.MinimalDataset`を継承するクラスを定義してください。そして各学習スクリプトのオプションで `--dataset_class package.module.DatasetClass` のように指定してください。 - - 実装方法は `MinimalDataset` を参考にしてください。のちほどサンプルを用意します。 -- 生成スクリプトに以下の機能追加を行いました。 - - Highres. Fixの2nd stageでControlNetを無効化するオプション `--highres_fix_disable_control_net` を追加しました。Canny等一部のControlNetで画像が乱れる場合にお試しください。 - - プロンプトでsd-dynamic-propmptsに似たVariantをサポートしました。 - - `{spring|summer|autumn|winter}` のように指定すると、いずれかがランダムに選択されます。 - - `{2$$chocolate|vanilla|strawberry}` のように指定すると、いずれか2個がランダムに選択されます。 - - `{1-2$$ and $$chocolate|vanilla|strawberry}` のように指定すると、1個か2個がランダムに選択され ` and ` で接続されます。 - - 個数のレンジ指定では`0-2`のように0個も指定可能です。`-2`や`1-`のような片側の省略はできません。 - - プロンプトオプションに対しても指定可能です。 - - `{e$$chocolate|vanilla|strawberry}` のように`e`または`E`を指定すると、すべての候補が選択されプロンプトが複数回繰り返されます(`--images_per_prompt`は無視されます)。X/Y plotの作成に便利かもしれません。 - - `--am {e$$0.2|0.4|0.6|0.8|1.0},{e$$0.4|0.7|1.0} --d 1234`のような指定も可能です。この場合、5*3で15回のプロンプトが生成されます。 - - Weightingの機能はありません。 - -### 8 Jun. 2023, 2023/06/08 - -- Fixed a bug where clip skip did not work when training with weighted captions (`--weighted_captions` specified) and when generating sample images during training. -- 重みづけキャプションでの学習時(`--weighted_captions`指定時)および学習中のサンプル画像生成時にclip skipが機能しない不具合を修正しました。 - -### 6 Jun. 2023, 2023/06/06 - -- Fix `train_network.py` to probably work with older versions of LyCORIS. -- `gen_img_diffusers.py` now supports `BREAK` syntax. -- `train_network.py`がLyCORISの以前のバージョンでも恐らく動作するよう修正しました。 -- `gen_img_diffusers.py` で `BREAK` 構文をサポートしました。 - -### 3 Jun. 2023, 2023/06/03 - -- Max Norm Regularization is now available in `train_network.py`. [PR #545](https://github.com/kohya-ss/sd-scripts/pull/545) Thanks to AI-Casanova! - - Max Norm Regularization is a technique to stabilize network training by limiting the norm of network weights. 
It may be effective in suppressing overfitting of LoRA and improving stability when used with other LoRAs. See PR for details. - - Specify as `--scale_weight_norms=1.0`. It seems good to try from `1.0`. - - The networks other than LoRA in this repository (such as LyCORIS) do not support this option. - -- Three types of dropout have been added to `train_network.py` and LoRA network. - - Dropout is a technique to suppress overfitting and improve network performance by randomly setting some of the network outputs to 0. - - `--network_dropout` is a normal dropout at the neuron level. In the case of LoRA, it is applied to the output of down. Proposed in [PR #545](https://github.com/kohya-ss/sd-scripts/pull/545) Thanks to AI-Casanova! - - `--network_dropout=0.1` specifies the dropout probability to `0.1`. - - Note that the specification method is different from LyCORIS. - - For LoRA network, `--network_args` can specify `rank_dropout` to dropout each rank with specified probability. Also `module_dropout` can be specified to dropout each module with specified probability. - - Specify as `--network_args "rank_dropout=0.2" "module_dropout=0.1"`. - - `--network_dropout`, `rank_dropout`, and `module_dropout` can be specified at the same time. - - Values of 0.1 to 0.3 may be good to try. Values greater than 0.5 should not be specified. - - `rank_dropout` and `module_dropout` are original techniques of this repository. Their effectiveness has not been verified yet. - - The networks other than LoRA in this repository (such as LyCORIS) do not support these options. - -- Added an option `--scale_v_pred_loss_like_noise_pred` to scale v-prediction loss like noise prediction in each training script. - - By scaling the loss according to the time step, the weights of global noise prediction and local noise prediction become the same, and the improvement of details may be expected. - - See [this article](https://xrg.hatenablog.com/entry/2023/06/02/202418) by xrg for details (written in Japanese). Thanks to xrg for the great suggestion! 
- -- Max Norm Regularizationが`train_network.py`で使えるようになりました。[PR #545](https://github.com/kohya-ss/sd-scripts/pull/545) AI-Casanova氏に感謝します。 - - Max Norm Regularizationは、ネットワークの重みのノルムを制限することで、ネットワークの学習を安定させる手法です。LoRAの過学習の抑制、他のLoRAと併用した時の安定性の向上が期待できるかもしれません。詳細はPRを参照してください。 - - `--scale_weight_norms=1.0`のように `--scale_weight_norms` で指定してください。`1.0`から試すと良いようです。 - - LyCORIS等、当リポジトリ以外のネットワークは現時点では未対応です。 - -- `train_network.py` およびLoRAに計三種類のdropoutを追加しました。 - - dropoutはネットワークの一部の出力をランダムに0にすることで、過学習の抑制、ネットワークの性能向上等を図る手法です。 - - `--network_dropout` はニューロン単位の通常のdropoutです。LoRAの場合、downの出力に対して適用されます。[PR #545](https://github.com/kohya-ss/sd-scripts/pull/545) で提案されました。AI-Casanova氏に感謝します。 - - `--network_dropout=0.1` などとすることで、dropoutの確率を指定できます。 - - LyCORISとは指定方法が異なりますのでご注意ください。 - - LoRAの場合、`--network_args`に`rank_dropout`を指定することで各rankを指定確率でdropoutします。また同じくLoRAの場合、`--network_args`に`module_dropout`を指定することで各モジュールを指定確率でdropoutします。 - - `--network_args "rank_dropout=0.2" "module_dropout=0.1"` のように指定します。 - - `--network_dropout`、`rank_dropout` 、 `module_dropout` は同時に指定できます。 - - それぞれの値は0.1~0.3程度から試してみると良いかもしれません。0.5を超える値は指定しない方が良いでしょう。 - - `rank_dropout`および`module_dropout`は当リポジトリ独自の手法です。有効性の検証はまだ行っていません。 - - これらのdropoutはLyCORIS等、当リポジトリ以外のネットワークは現時点では未対応です。 - -- 各学習スクリプトにv-prediction lossをnoise predictionと同様の値にスケールするオプション`--scale_v_pred_loss_like_noise_pred`を追加しました。 - - タイムステップに応じてlossをスケールすることで、 大域的なノイズの予測と局所的なノイズの予測の重みが同じになり、ディテールの改善が期待できるかもしれません。 - - 詳細はxrg氏のこちらの記事をご参照ください:[noise_predictionモデルとv_predictionモデルの損失 - 勾配降下党青年局](https://xrg.hatenablog.com/entry/2023/06/02/202418) xrg氏の素晴らしい記事に感謝します。 - -### 31 May 2023, 2023/05/31 - -- Show warning when image caption file does not exist during training. [PR #533](https://github.com/kohya-ss/sd-scripts/pull/533) Thanks to TingTingin! - - Warning is also displayed when using class+identifier dataset. Please ignore if it is intended. -- `train_network.py` now supports merging network weights before training. [PR #542](https://github.com/kohya-ss/sd-scripts/pull/542) Thanks to u-haru! - - `--base_weights` option specifies LoRA or other model files (multiple files are allowed) to merge. - - `--base_weights_multiplier` option specifies multiplier of the weights to merge (multiple values are allowed). If omitted or less than `base_weights`, 1.0 is used. - - This is useful for incremental learning. See PR for details. -- Show warning and continue training when uploading to HuggingFace fails. - -- 学習時に画像のキャプションファイルが存在しない場合、警告が表示されるようになりました。 [PR #533](https://github.com/kohya-ss/sd-scripts/pull/533) TingTingin氏に感謝します。 - - class+identifier方式のデータセットを利用している場合も警告が表示されます。意図している通りの場合は無視してください。 -- `train_network.py` に学習前にモデルにnetworkの重みをマージする機能が追加されました。 [PR #542](https://github.com/kohya-ss/sd-scripts/pull/542) u-haru氏に感謝します。 - - `--base_weights` オプションでLoRA等のモデルファイル(複数可)を指定すると、それらの重みをマージします。 - - `--base_weights_multiplier` オプションでマージする重みの倍率(複数可)を指定できます。省略時または`base_weights`よりも数が少ない場合は1.0になります。 - - 差分追加学習などにご利用ください。詳細はPRをご覧ください。 -- HuggingFaceへのアップロードに失敗した場合、警告を表示しそのまま学習を続行するよう変更しました。 - -### 25 May 2023, 2023/05/25 - -- [D-Adaptation v3.0](https://github.com/facebookresearch/dadaptation) is now supported. [PR #530](https://github.com/kohya-ss/sd-scripts/pull/530) Thanks to sdbds! - - `--optimizer_type` now accepts `DAdaptAdamPreprint`, `DAdaptAdanIP`, and `DAdaptLion`. - - `DAdaptAdam` is now new. The old `DAdaptAdam` is available with `DAdaptAdamPreprint`. - - Simply specifying `DAdaptation` will use `DAdaptAdamPreprint` (same behavior as before). - - You need to install D-Adaptation v3.0. 
After activating venv, please do `pip install -U dadaptation`. - - See PR and D-Adaptation documentation for details. -- [D-Adaptation v3.0](https://github.com/facebookresearch/dadaptation)がサポートされました。 [PR #530](https://github.com/kohya-ss/sd-scripts/pull/530) sdbds氏に感謝します。 - - `--optimizer_type`に`DAdaptAdamPreprint`、`DAdaptAdanIP`、`DAdaptLion` が追加されました。 - - `DAdaptAdam`が新しくなりました。今までの`DAdaptAdam`は`DAdaptAdamPreprint`で使用できます。 - - 単に `DAdaptation` を指定すると`DAdaptAdamPreprint`が使用されます(今までと同じ動き)。 - - D-Adaptation v3.0のインストールが必要です。venvを有効にした後 `pip install -U dadaptation` としてください。 - - 詳細はPRおよびD-Adaptationのドキュメントを参照してください。 - -### 22 May 2023, 2023/05/22 - -- Fixed several bugs. - - The state is saved even when the `--save_state` option is not specified in `fine_tune.py` and `train_db.py`. [PR #521](https://github.com/kohya-ss/sd-scripts/pull/521) Thanks to akshaal! - - Cannot load LoRA without `alpha`. [PR #527](https://github.com/kohya-ss/sd-scripts/pull/527) Thanks to Manjiz! - - Minor changes to console output during sample generation. [PR #515](https://github.com/kohya-ss/sd-scripts/pull/515) Thanks to yanhuifair! -- The generation script now uses xformers for VAE as well. -- いくつかのバグ修正を行いました。 - - `fine_tune.py`と`train_db.py`で`--save_state`オプション未指定時にもstateが保存される。 [PR #521](https://github.com/kohya-ss/sd-scripts/pull/521) akshaal氏に感謝します。 - - `alpha`を持たないLoRAを読み込めない。[PR #527](https://github.com/kohya-ss/sd-scripts/pull/527) Manjiz氏に感謝します。 - - サンプル生成時のコンソール出力の軽微な変更。[PR #515](https://github.com/kohya-ss/sd-scripts/pull/515) yanhuifair氏に感謝します。 -- 生成スクリプトでVAEについてもxformersを使うようにしました。 - -### 16 May 2023, 2023/05/16 - -- Fixed an issue where an error would occur if the encoding of the prompt file was different from the default. [PR #510](https://github.com/kohya-ss/sd-scripts/pull/510) Thanks to sdbds! - - Please save the prompt file in UTF-8. -- プロンプトファイルのエンコーディングがデフォルトと異なる場合にエラーが発生する問題を修正しました。 [PR #510](https://github.com/kohya-ss/sd-scripts/pull/510) sdbds氏に感謝します。 - - プロンプトファイルはUTF-8で保存してください。 - -### 15 May 2023, 2023/05/15 - -- Added [English translation of documents](https://github.com/darkstorm2150/sd-scripts#links-to-usage-documentation) by darkstorm2150. Thank you very much! -- The prompt for sample generation during training can now be specified in `.toml` or `.json`. [PR #504](https://github.com/kohya-ss/sd-scripts/pull/504) Thanks to Linaqruf! - - For details on prompt description, please see the PR. - -- darkstorm2150氏に[ドキュメント類を英訳](https://github.com/darkstorm2150/sd-scripts#links-to-usage-documentation)していただきました。ありがとうございます! -- 学習中のサンプル生成のプロンプトを`.toml`または`.json`で指定可能になりました。 [PR #504](https://github.com/kohya-ss/sd-scripts/pull/504) Linaqruf氏に感謝します。 - - プロンプト記述の詳細は当該PRをご覧ください。 - -### 11 May 2023, 2023/05/11 - -- Added an option `--dim_from_weights` to `train_network.py` to automatically determine the dim(rank) from the weight file. [PR #491](https://github.com/kohya-ss/sd-scripts/pull/491) Thanks to AI-Casanova! - - It is useful in combination with `resize_lora.py`. Please see the PR for details. -- Fixed a bug where the noise resolution was incorrect with Multires noise. [PR #489](https://github.com/kohya-ss/sd-scripts/pull/489) Thanks to sdbds! - - Please see the PR for details. -- The image generation scripts can now use img2img and highres fix at the same time. -- Fixed a bug where the hint image of ControlNet was incorrectly BGR instead of RGB in the image generation scripts. -- Added a feature to the image generation scripts to use the memory-efficient VAE. 
- - If you specify a number with the `--vae_slices` option, the memory-efficient VAE will be used. The maximum output size will be larger, but it will be slower. Please specify a value of about `16` or `32`. - - The implementation of the VAE is in `library/slicing_vae.py`. - -- `train_network.py`にdim(rank)を重みファイルから自動決定するオプション`--dim_from_weights`が追加されました。 [PR #491](https://github.com/kohya-ss/sd-scripts/pull/491) AI-Casanova氏に感謝します。 - - `resize_lora.py`と組み合わせると有用です。詳細はPRもご参照ください。 -- Multires noiseでノイズ解像度が正しくない不具合が修正されました。 [PR #489](https://github.com/kohya-ss/sd-scripts/pull/489) sdbds氏に感謝します。 - - 詳細は当該PRをご参照ください。 -- 生成スクリプトでimg2imgとhighres fixを同時に使用できるようにしました。 -- 生成スクリプトでControlNetのhint画像が誤ってBGRだったのをRGBに修正しました。 -- 生成スクリプトで省メモリ化VAEを使えるよう機能追加しました。 - - `--vae_slices`オプションに数値を指定すると、省メモリ化VAEを用います。出力可能な最大サイズが大きくなりますが、遅くなります。`16`または`32`程度の値を指定してください。 - - VAEの実装は`library/slicing_vae.py`にあります。 - -### 7 May 2023, 2023/05/07 - -- The documentation has been moved to the `docs` folder. If you have links, please change them. -- Removed `gradio` from `requirements.txt`. -- DAdaptAdaGrad, DAdaptAdan, and DAdaptSGD are now supported by DAdaptation. [PR#455](https://github.com/kohya-ss/sd-scripts/pull/455) Thanks to sdbds! - - DAdaptation needs to be installed. Also, depending on the optimizer, DAdaptation may need to be updated. Please update with `pip install --upgrade dadaptation`. -- Added support for pre-calculation of LoRA weights in image generation scripts. Specify `--network_pre_calc`. - - The prompt option `--am` is available. Also, it is disabled when Regional LoRA is used. -- Added Adaptive noise scale to each training script. Specify a number with `--adaptive_noise_scale` to enable it. - - __Experimental option. It may be removed or changed in the future.__ - - This is an original implementation that automatically adjusts the value of the noise offset according to the absolute value of the mean of each channel of the latents. It is expected that appropriate noise offsets will be set for bright and dark images, respectively. - - Specify it together with `--noise_offset`. - - The actual value of the noise offset is calculated as `noise_offset + abs(mean(latents, dim=(2,3))) * adaptive_noise_scale`. Since the latent is close to a normal distribution, it may be a good idea to specify a value of about 1/10 to the same as the noise offset. - - Negative values can also be specified, in which case the noise offset will be clipped to 0 or more. -- Other minor fixes. - -- ドキュメントを`docs`フォルダに移動しました。リンク等を張られている場合は変更をお願いいたします。 -- `requirements.txt`から`gradio`を削除しました。 -- DAdaptationで新しくDAdaptAdaGrad、DAdaptAdan、DAdaptSGDがサポートされました。[PR#455](https://github.com/kohya-ss/sd-scripts/pull/455) sdbds氏に感謝します。 - - dadaptationのインストールが必要です。またオプティマイザによってはdadaptationの更新が必要です。`pip install --upgrade dadaptation`で更新してください。 -- 画像生成スクリプトでLoRAの重みの事前計算をサポートしました。`--network_pre_calc`を指定してください。 - - プロンプトオプションの`--am`が利用できます。またRegional LoRA使用時には無効になります。 -- 各学習スクリプトにAdaptive noise scaleを追加しました。`--adaptive_noise_scale`で数値を指定すると有効になります。 - - __実験的オプションです。将来的に削除、仕様変更される可能性があります。__ - - Noise offsetの値を、latentsの各チャネルの平均値の絶対値に応じて自動調整するオプションです。独自の実装で、明るい画像、暗い画像に対してそれぞれ適切なnoise offsetが設定されることが期待されます。 - - `--noise_offset` と同時に指定してください。 - - 実際のNoise offsetの値は `noise_offset + abs(mean(latents, dim=(2,3))) * adaptive_noise_scale` で計算されます。 latentは正規分布に近いためnoise_offsetの1/10~同程度の値を指定するとよいかもしれません。 - - 負の値も指定でき、その場合はnoise offsetは0以上にclipされます。 -- その他の細かい修正を行いました。 +### Oct 1. 2023 / 2023/10/1 + +- SDXL training is now available in the main branch. 
The sdxl branch is merged into the main branch. + +- [SAI Model Spec](https://github.com/Stability-AI/ModelSpec) metadata is now supported partially. `hash_sha256` is not supported yet. + - The main items are set automatically. + - You can set title, author, description, license and tags with `--metadata_xxx` options in each training script. + - Merging scripts also support minimum SAI Model Spec metadata. See the help message for the usage. + - Metadata editor will be available soon. + +- `bitsandbytes` is now optional. Please install it if you want to use it. The insructions are in the later section. + +- `albumentations` is not required anymore. + +- `--v_pred_like_loss ratio` option is added. This option adds the loss like v-prediction loss in SDXL training. `0.1` means that the loss is added 10% of the v-prediction loss. The default value is None (disabled). + - In v-prediction, the loss is higher in the early timesteps (near the noise). This option can be used to increase the loss in the early timesteps. + +- Arbitrary options can be used for Diffusers' schedulers. For example `--lr_scheduler_args "lr_end=1e-8"`. + +- LoRA-FA is added experimentally. Specify `--network_module networks.lora_fa` option instead of `--network_module networks.lora`. The trained model can be used as a normal LoRA model. +- JPEG XL is supported. [#786](https://github.com/kohya-ss/sd-scripts/pull/786) +- Input perturbation noise is added. See [#798](https://github.com/kohya-ss/sd-scripts/pull/798) for details. +- Dataset subset now has `caption_prefix` and `caption_suffix` options. The strings are added to the beginning and the end of the captions before shuffling. You can specify the options in `.toml`. +- Intel ARC support with IPEX is added. [#825](https://github.com/kohya-ss/sd-scripts/pull/825) +- Other bug fixes and improvements. + Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates. 最近の更新情報は [Release](https://github.com/kohya-ss/sd-scripts/releases) をご覧ください。 diff --git a/XTI_hijack.py b/XTI_hijack.py index f39cc8e7e..ec0849455 100644 --- a/XTI_hijack.py +++ b/XTI_hijack.py @@ -1,133 +1,131 @@ import torch +try: + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + from library.ipex import ipex_init + ipex_init() +except Exception: + pass from typing import Union, List, Optional, Dict, Any, Tuple from diffusers.models.unet_2d_condition import UNet2DConditionOutput -def unet_forward_XTI(self, - sample: torch.FloatTensor, - timestep: Union[torch.Tensor, float, int], - encoder_hidden_states: torch.Tensor, - class_labels: Optional[torch.Tensor] = None, - return_dict: bool = True, - ) -> Union[UNet2DConditionOutput, Tuple]: - r""" - Args: - sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor - timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps - encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. - - Returns: - [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: - [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. - """ - # By default samples have to be AT least a multiple of the overall upsampling factor. 
- # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). - # However, the upsampling interpolation output size can be forced to fit any upsampling size - # on the fly if necessary. - default_overall_up_factor = 2**self.num_upsamplers - - # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` - forward_upsample_size = False - upsample_size = None - - if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): - logger.info("Forward upsample size to force interpolation output size.") - forward_upsample_size = True - - # 0. center input if necessary - if self.config.center_input_sample: - sample = 2 * sample - 1.0 - - # 1. time - timesteps = timestep - if not torch.is_tensor(timesteps): - # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can - # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - if isinstance(timestep, float): - dtype = torch.float32 if is_mps else torch.float64 - else: - dtype = torch.int32 if is_mps else torch.int64 - timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) - elif len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps.expand(sample.shape[0]) - - t_emb = self.time_proj(timesteps) - - # timesteps does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=self.dtype) - emb = self.time_embedding(t_emb) - - if self.config.num_class_embeds is not None: - if class_labels is None: - raise ValueError("class_labels should be provided when num_class_embeds > 0") - class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) - emb = emb + class_emb - - # 2. pre-process - sample = self.conv_in(sample) - - # 3. down - down_block_res_samples = (sample,) - down_i = 0 - for downsample_block in self.down_blocks: - if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: - sample, res_samples = downsample_block( - hidden_states=sample, - temb=emb, - encoder_hidden_states=encoder_hidden_states[down_i:down_i+2], - ) - down_i += 2 - else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb) - - down_block_res_samples += res_samples - - # 4. mid - sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states[6]) - - # 5. 
up - up_i = 7 - for i, upsample_block in enumerate(self.up_blocks): - is_final_block = i == len(self.up_blocks) - 1 - - res_samples = down_block_res_samples[-len(upsample_block.resnets) :] - down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] - - # if we have not reached the final block and need to forward the - # upsample size, we do it here - if not is_final_block and forward_upsample_size: - upsample_size = down_block_res_samples[-1].shape[2:] - - if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: - sample = upsample_block( - hidden_states=sample, - temb=emb, - res_hidden_states_tuple=res_samples, - encoder_hidden_states=encoder_hidden_states[up_i:up_i+3], - upsample_size=upsample_size, - ) - up_i += 3 - else: - sample = upsample_block( - hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size - ) - # 6. post-process - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - - if not return_dict: - return (sample,) - - return UNet2DConditionOutput(sample=sample) +from library.original_unet import SampleOutput + + +def unet_forward_XTI( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + return_dict: bool = True, +) -> Union[Dict, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a dict instead of a plain tuple. + + Returns: + `SampleOutput` or `tuple`: + `SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + # デフォルトではサンプルは「2^アップサンプルの数」、つまり64の倍数である必要がある + # ただそれ以外のサイズにも対応できるように、必要ならアップサンプルのサイズを変更する + # 多分画質が悪くなるので、64で割り切れるようにしておくのが良い + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + # 64で割り切れないときはupsamplerにサイズを伝える + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + # logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # 1. time + timesteps = timestep + timesteps = self.handle_unusual_timesteps(sample, timesteps) # 変な時だけ処理 + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + # timestepsは重みを含まないので常にfloat32のテンソルを返す + # しかしtime_embeddingはfp16で動いているかもしれないので、ここでキャストする必要がある + # time_projでキャストしておけばいいんじゃね? + t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) + + # 2. pre-process + sample = self.conv_in(sample) + + # 3. 
down + down_block_res_samples = (sample,) + down_i = 0 + for downsample_block in self.down_blocks: + # downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、 + # まあこちらのほうがわかりやすいかもしれない + if downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states[down_i : down_i + 2], + ) + down_i += 2 + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states[6]) + + # 5. up + up_i = 7 + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection + + # if we have not reached the final block and need to forward the upsample size, we do it here + # 前述のように最後のブロック以外ではupsample_sizeを伝える + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states[up_i : up_i + 3], + upsample_size=upsample_size, + ) + up_i += 3 + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + + # 6. post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + + return SampleOutput(sample=sample) + def downblock_forward_XTI( self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None @@ -166,6 +164,7 @@ def custom_forward(*inputs): return hidden_states, output_states + def upblock_forward_XTI( self, hidden_states, @@ -199,11 +198,11 @@ def custom_forward(*inputs): else: hidden_states = resnet(hidden_states, temb) hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample - + i += 1 if self.upsamplers is not None: for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states, upsample_size) - return hidden_states \ No newline at end of file + return hidden_states diff --git a/bitsandbytes_windows/libbitsandbytes_cuda118.dll b/bitsandbytes_windows/libbitsandbytes_cuda118.dll new file mode 100644 index 000000000..a54cc960b Binary files /dev/null and b/bitsandbytes_windows/libbitsandbytes_cuda118.dll differ diff --git a/bitsandbytes_windows/main.py b/bitsandbytes_windows/main.py index 7e5f9c981..380f85aec 100644 --- a/bitsandbytes_windows/main.py +++ b/bitsandbytes_windows/main.py @@ -1,166 +1,166 @@ -""" -extract factors the build is dependent on: -[X] compute capability - [ ] TODO: Q - What if we have multiple GPUs of different makes? -- CUDA version -- Software: - - CPU-only: only CPU quantization functions (no optimizer, no matrix multiple) - - CuBLAS-LT: full-build 8-bit optimizer - - no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`) - -evaluation: - - if paths faulty, return meaningful error - - else: - - determine CUDA version - - determine capabilities - - based on that set the default path -""" - -import ctypes - -from .paths import determine_cuda_runtime_lib_path - - -def check_cuda_result(cuda, result_val): - # 3. 
Check for CUDA errors - if result_val != 0: - error_str = ctypes.c_char_p() - cuda.cuGetErrorString(result_val, ctypes.byref(error_str)) - print(f"CUDA exception! Error code: {error_str.value.decode()}") - -def get_cuda_version(cuda, cudart_path): - # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION - try: - cudart = ctypes.CDLL(cudart_path) - except OSError: - # TODO: shouldn't we error or at least warn here? - print(f'ERROR: libcudart.so could not be read from path: {cudart_path}!') - return None - - version = ctypes.c_int() - check_cuda_result(cuda, cudart.cudaRuntimeGetVersion(ctypes.byref(version))) - version = int(version.value) - major = version//1000 - minor = (version-(major*1000))//10 - - if major < 11: - print('CUDA SETUP: CUDA version lower than 11 are currently not supported for LLM.int8(). You will be only to use 8-bit optimizers and quantization routines!!') - - return f'{major}{minor}' - - -def get_cuda_lib_handle(): - # 1. find libcuda.so library (GPU driver) (/usr/lib) - try: - cuda = ctypes.CDLL("libcuda.so") - except OSError: - # TODO: shouldn't we error or at least warn here? - print('CUDA SETUP: WARNING! libcuda.so not found! Do you have a CUDA driver installed? If you are on a cluster, make sure you are on a CUDA machine!') - return None - check_cuda_result(cuda, cuda.cuInit(0)) - - return cuda - - -def get_compute_capabilities(cuda): - """ - 1. find libcuda.so library (GPU driver) (/usr/lib) - init_device -> init variables -> call function by reference - 2. call extern C function to determine CC - (https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html) - 3. Check for CUDA errors - https://stackoverflow.com/questions/14038589/what-is-the-canonical-way-to-check-for-errors-using-the-cuda-runtime-api - # bits taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549 - """ - - - nGpus = ctypes.c_int() - cc_major = ctypes.c_int() - cc_minor = ctypes.c_int() - - device = ctypes.c_int() - - check_cuda_result(cuda, cuda.cuDeviceGetCount(ctypes.byref(nGpus))) - ccs = [] - for i in range(nGpus.value): - check_cuda_result(cuda, cuda.cuDeviceGet(ctypes.byref(device), i)) - ref_major = ctypes.byref(cc_major) - ref_minor = ctypes.byref(cc_minor) - # 2. call extern C function to determine CC - check_cuda_result( - cuda, cuda.cuDeviceComputeCapability(ref_major, ref_minor, device) - ) - ccs.append(f"{cc_major.value}.{cc_minor.value}") - - return ccs - - -# def get_compute_capability()-> Union[List[str, ...], None]: # FIXME: error -def get_compute_capability(cuda): - """ - Extracts the highest compute capbility from all available GPUs, as compute - capabilities are downwards compatible. If no GPUs are detected, it returns - None. - """ - ccs = get_compute_capabilities(cuda) - if ccs is not None: - # TODO: handle different compute capabilities; for now, take the max - return ccs[-1] - return None - - -def evaluate_cuda_setup(): - print('') - print('='*35 + 'BUG REPORT' + '='*35) - print('Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues') - print('For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link') - print('='*80) - return "libbitsandbytes_cuda116.dll" # $$$ - - binary_name = "libbitsandbytes_cpu.so" - #if not torch.cuda.is_available(): - #print('No GPU detected. 
Loading CPU library...') - #return binary_name - - cudart_path = determine_cuda_runtime_lib_path() - if cudart_path is None: - print( - "WARNING: No libcudart.so found! Install CUDA or the cudatoolkit package (anaconda)!" - ) - return binary_name - - print(f"CUDA SETUP: CUDA runtime path found: {cudart_path}") - cuda = get_cuda_lib_handle() - cc = get_compute_capability(cuda) - print(f"CUDA SETUP: Highest compute capability among GPUs detected: {cc}") - cuda_version_string = get_cuda_version(cuda, cudart_path) - - - if cc == '': - print( - "WARNING: No GPU detected! Check your CUDA paths. Processing to load CPU-only library..." - ) - return binary_name - - # 7.5 is the minimum CC vor cublaslt - has_cublaslt = cc in ["7.5", "8.0", "8.6"] - - # TODO: - # (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible) - # (2) Multiple CUDA versions installed - - # we use ls -l instead of nvcc to determine the cuda version - # since most installations will have the libcudart.so installed, but not the compiler - print(f'CUDA SETUP: Detected CUDA version {cuda_version_string}') - - def get_binary_name(): - "if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt.so" - bin_base_name = "libbitsandbytes_cuda" - if has_cublaslt: - return f"{bin_base_name}{cuda_version_string}.so" - else: - return f"{bin_base_name}{cuda_version_string}_nocublaslt.so" - - binary_name = get_binary_name() - - return binary_name +""" +extract factors the build is dependent on: +[X] compute capability + [ ] TODO: Q - What if we have multiple GPUs of different makes? +- CUDA version +- Software: + - CPU-only: only CPU quantization functions (no optimizer, no matrix multiple) + - CuBLAS-LT: full-build 8-bit optimizer + - no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`) + +evaluation: + - if paths faulty, return meaningful error + - else: + - determine CUDA version + - determine capabilities + - based on that set the default path +""" + +import ctypes + +from .paths import determine_cuda_runtime_lib_path + + +def check_cuda_result(cuda, result_val): + # 3. Check for CUDA errors + if result_val != 0: + error_str = ctypes.c_char_p() + cuda.cuGetErrorString(result_val, ctypes.byref(error_str)) + print(f"CUDA exception! Error code: {error_str.value.decode()}") + +def get_cuda_version(cuda, cudart_path): + # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION + try: + cudart = ctypes.CDLL(cudart_path) + except OSError: + # TODO: shouldn't we error or at least warn here? + print(f'ERROR: libcudart.so could not be read from path: {cudart_path}!') + return None + + version = ctypes.c_int() + check_cuda_result(cuda, cudart.cudaRuntimeGetVersion(ctypes.byref(version))) + version = int(version.value) + major = version//1000 + minor = (version-(major*1000))//10 + + if major < 11: + print('CUDA SETUP: CUDA version lower than 11 are currently not supported for LLM.int8(). You will be only to use 8-bit optimizers and quantization routines!!') + + return f'{major}{minor}' + + +def get_cuda_lib_handle(): + # 1. find libcuda.so library (GPU driver) (/usr/lib) + try: + cuda = ctypes.CDLL("libcuda.so") + except OSError: + # TODO: shouldn't we error or at least warn here? + print('CUDA SETUP: WARNING! libcuda.so not found! Do you have a CUDA driver installed? If you are on a cluster, make sure you are on a CUDA machine!') + return None + check_cuda_result(cuda, cuda.cuInit(0)) + + return cuda + + +def get_compute_capabilities(cuda): + """ + 1. 
find libcuda.so library (GPU driver) (/usr/lib) + init_device -> init variables -> call function by reference + 2. call extern C function to determine CC + (https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html) + 3. Check for CUDA errors + https://stackoverflow.com/questions/14038589/what-is-the-canonical-way-to-check-for-errors-using-the-cuda-runtime-api + # bits taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549 + """ + + + nGpus = ctypes.c_int() + cc_major = ctypes.c_int() + cc_minor = ctypes.c_int() + + device = ctypes.c_int() + + check_cuda_result(cuda, cuda.cuDeviceGetCount(ctypes.byref(nGpus))) + ccs = [] + for i in range(nGpus.value): + check_cuda_result(cuda, cuda.cuDeviceGet(ctypes.byref(device), i)) + ref_major = ctypes.byref(cc_major) + ref_minor = ctypes.byref(cc_minor) + # 2. call extern C function to determine CC + check_cuda_result( + cuda, cuda.cuDeviceComputeCapability(ref_major, ref_minor, device) + ) + ccs.append(f"{cc_major.value}.{cc_minor.value}") + + return ccs + + +# def get_compute_capability()-> Union[List[str, ...], None]: # FIXME: error +def get_compute_capability(cuda): + """ + Extracts the highest compute capbility from all available GPUs, as compute + capabilities are downwards compatible. If no GPUs are detected, it returns + None. + """ + ccs = get_compute_capabilities(cuda) + if ccs is not None: + # TODO: handle different compute capabilities; for now, take the max + return ccs[-1] + return None + + +def evaluate_cuda_setup(): + print('') + print('='*35 + 'BUG REPORT' + '='*35) + print('Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues') + print('For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link') + print('='*80) + return "libbitsandbytes_cuda116.dll" # $$$ + + binary_name = "libbitsandbytes_cpu.so" + #if not torch.cuda.is_available(): + #print('No GPU detected. Loading CPU library...') + #return binary_name + + cudart_path = determine_cuda_runtime_lib_path() + if cudart_path is None: + print( + "WARNING: No libcudart.so found! Install CUDA or the cudatoolkit package (anaconda)!" + ) + return binary_name + + print(f"CUDA SETUP: CUDA runtime path found: {cudart_path}") + cuda = get_cuda_lib_handle() + cc = get_compute_capability(cuda) + print(f"CUDA SETUP: Highest compute capability among GPUs detected: {cc}") + cuda_version_string = get_cuda_version(cuda, cudart_path) + + + if cc == '': + print( + "WARNING: No GPU detected! Check your CUDA paths. Processing to load CPU-only library..." 
+ ) + return binary_name + + # 7.5 is the minimum CC vor cublaslt + has_cublaslt = cc in ["7.5", "8.0", "8.6"] + + # TODO: + # (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible) + # (2) Multiple CUDA versions installed + + # we use ls -l instead of nvcc to determine the cuda version + # since most installations will have the libcudart.so installed, but not the compiler + print(f'CUDA SETUP: Detected CUDA version {cuda_version_string}') + + def get_binary_name(): + "if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt.so" + bin_base_name = "libbitsandbytes_cuda" + if has_cublaslt: + return f"{bin_base_name}{cuda_version_string}.so" + else: + return f"{bin_base_name}{cuda_version_string}_nocublaslt.so" + + binary_name = get_binary_name() + + return binary_name diff --git a/docs/config_README-ja.md b/docs/config_README-ja.md index 7f2b6c4c1..69a03f6cf 100644 --- a/docs/config_README-ja.md +++ b/docs/config_README-ja.md @@ -138,9 +138,13 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学 | `num_repeats` | `10` | o | o | o | | `random_crop` | `false` | o | o | o | | `shuffle_caption` | `true` | o | o | o | +| `caption_prefix` | `“masterpiece, best quality, ”` | o | o | o | +| `caption_suffix` | `“, from side”` | o | o | o | * `num_repeats` * サブセットの画像の繰り返し回数を指定します。fine tuning における `--dataset_repeats` に相当しますが、`num_repeats` はどの学習方法でも指定可能です。 +* `caption_prefix`, `caption_suffix` + * キャプションの前、後に付与する文字列を指定します。シャッフルはこれらの文字列を含めた状態で行われます。`keep_tokens` を指定する場合には注意してください。 ### DreamBooth 方式専用のオプション diff --git a/docs/train_lllite_README-ja.md b/docs/train_lllite_README-ja.md new file mode 100644 index 000000000..dbdc1fea2 --- /dev/null +++ b/docs/train_lllite_README-ja.md @@ -0,0 +1,214 @@ +# ControlNet-LLLite について + +__きわめて実験的な実装のため、将来的に大きく変更される可能性があります。__ + +## 概要 +ControlNet-LLLite は、[ControlNet](https://github.com/lllyasviel/ControlNet) の軽量版です。LoRA Like Lite という意味で、LoRAからインスピレーションを得た構造を持つ、軽量なControlNetです。現在はSDXLにのみ対応しています。 + +## サンプルの重みファイルと推論 + +こちらにあります: https://huggingface.co/kohya-ss/controlnet-lllite + +ComfyUIのカスタムノードを用意しています。: https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI + +生成サンプルはこのページの末尾にあります。 + +## モデル構造 +ひとつのLLLiteモジュールは、制御用画像(以下conditioning image)を潜在空間に写像するconditioning image embeddingと、LoRAにちょっと似た構造を持つ小型のネットワークからなります。LLLiteモジュールを、LoRAと同様にU-NetのLinearやConvに追加します。詳しくはソースコードを参照してください。 + +推論環境の制限で、現在はCrossAttentionのみ(attn1のq/k/v、attn2のq)に追加されます。 + +## モデルの学習 + +### データセットの準備 +通常のdatasetに加え、`conditioning_data_dir` で指定したディレクトリにconditioning imageを格納してください。conditioning imageは学習用画像と同じbasenameを持つ必要があります。また、conditioning imageは学習用画像と同じサイズに自動的にリサイズされます。conditioning imageにはキャプションファイルは不要です。 + +たとえば DreamBooth 方式でキャプションファイルを用いる場合の設定ファイルは以下のようになります。 + +```toml +[[datasets.subsets]] +image_dir = "path/to/image/dir" +caption_extension = ".txt" +conditioning_data_dir = "path/to/conditioning/image/dir" +``` + +現時点の制約として、random_cropは使用できません。 + +学習データとしては、元のモデルで生成した画像を学習用画像として、そこから加工した画像をconditioning imageとした、合成によるデータセットを用いるのがもっとも簡単です(データセットの品質的には問題があるかもしれません)。具体的なデータセットの合成方法については後述します。 + +なお、元モデルと異なる画風の画像を学習用画像とすると、制御に加えて、その画風についても学ぶ必要が生じます。ControlNet-LLLiteは容量が少ないため、画風学習には不向きです。このような場合には、後述の次元数を多めにしてください。 + +### 学習 +スクリプトで生成する場合は、`sdxl_train_control_net_lllite.py` を実行してください。`--cond_emb_dim` でconditioning image embeddingの次元数を指定できます。`--network_dim` でLoRA的モジュールのrankを指定できます。その他のオプションは`sdxl_train_network.py`に準じますが、`--network_module`の指定は不要です。 + +学習時にはメモリを大量に使用しますので、キャッシュやgradient checkpointingなどの省メモリ化のオプションを有効にしてください。また`--full_bf16` オプションで、BFloat16を使用するのも有効です(RTX 
30シリーズ以降のGPUが必要です)。24GB VRAMで動作確認しています。 + +conditioning image embeddingの次元数は、サンプルのCannyでは32を指定しています。LoRA的モジュールのrankは同じく64です。対象とするconditioning imageの特徴に合わせて調整してください。 + +(サンプルのCannyは恐らくかなり難しいと思われます。depthなどでは半分程度にしてもいいかもしれません。) + +以下は .toml の設定例です。 + +```toml +pretrained_model_name_or_path = "/path/to/model_trained_on.safetensors" +max_train_epochs = 12 +max_data_loader_n_workers = 4 +persistent_data_loader_workers = true +seed = 42 +gradient_checkpointing = true +mixed_precision = "bf16" +save_precision = "bf16" +full_bf16 = true +optimizer_type = "adamw8bit" +learning_rate = 2e-4 +xformers = true +output_dir = "/path/to/output/dir" +output_name = "output_name" +save_every_n_epochs = 1 +save_model_as = "safetensors" +vae_batch_size = 4 +cache_latents = true +cache_latents_to_disk = true +cache_text_encoder_outputs = true +cache_text_encoder_outputs_to_disk = true +network_dim = 64 +cond_emb_dim = 32 +dataset_config = "/path/to/dataset.toml" +``` + +### 推論 + +スクリプトで生成する場合は、`sdxl_gen_img.py` を実行してください。`--control_net_lllite_models` でLLLiteのモデルファイルを指定できます。次元数はモデルファイルから自動取得します。 + +`--guide_image_path`で推論に用いるconditioning imageを指定してください。なおpreprocessは行われないため、たとえばCannyならCanny処理を行った画像を指定してください(背景黒に白線)。`--control_net_preps`, `--control_net_weights`, `--control_net_ratios` には未対応です。 + +## データセットの合成方法 + +### 学習用画像の生成 + +学習のベースとなるモデルで画像生成を行います。Web UIやComfyUIなどで生成してください。画像サイズはモデルのデフォルトサイズで良いと思われます(1024x1024など)。bucketingを用いることもできます。その場合は適宜適切な解像度で生成してください。 + +生成時のキャプション等は、ControlNet-LLLiteの利用時に生成したい画像にあわせるのが良いと思われます。 + +生成した画像を任意のディレクトリに保存してください。このディレクトリをデータセットの設定ファイルで指定します。 + +当リポジトリ内の `sdxl_gen_img.py` でも生成できます。例えば以下のように実行します。 + +```dos +python sdxl_gen_img.py --ckpt path/to/model.safetensors --n_iter 1 --scale 10 --steps 36 --outdir path/to/output/dir --xformers --W 1024 --H 1024 --original_width 2048 --original_height 2048 --bf16 --sampler ddim --batch_size 4 --vae_batch_size 2 --images_per_prompt 512 --max_embeddings_multiples 1 --prompt "{portrait|digital art|anime screen cap|detailed illustration} of 1girl, {standing|sitting|walking|running|dancing} on {classroom|street|town|beach|indoors|outdoors}, {looking at viewer|looking away|looking at another}, {in|wearing} {shirt and skirt|school uniform|casual wear} { |, dynamic pose}, (solo), teen age, {0-1$$smile,|blush,|kind smile,|expression less,|happy,|sadness,} {0-1$$upper body,|full body,|cowboy shot,|face focus,} trending on pixiv, {0-2$$depth of fields,|8k wallpaper,|highly detailed,|pov,} {0-1$$summer, |winter, |spring, |autumn, } beautiful face { |, from below|, from above|, from side|, from behind|, from back} --n nsfw, bad face, lowres, low quality, worst quality, low effort, watermark, signature, ugly, poorly drawn" +``` + +VRAM 24GBの設定です。VRAMサイズにより`--batch_size` `--vae_batch_size`を調整してください。 + +`--prompt`でワイルドカードを利用してランダムに生成しています。適宜調整してください。 + +### 画像の加工 + +外部のプログラムを用いて、生成した画像を加工します。加工した画像を任意のディレクトリに保存してください。これらがconditioning imageになります。 + +加工にはたとえばCannyなら以下のようなスクリプトが使えます。 + +```python +import glob +import os +import random +import cv2 +import numpy as np + +IMAGES_DIR = "path/to/generated/images" +CANNY_DIR = "path/to/canny/images" + +os.makedirs(CANNY_DIR, exist_ok=True) +img_files = glob.glob(IMAGES_DIR + "/*.png") +for img_file in img_files: + can_file = CANNY_DIR + "/" + os.path.basename(img_file) + if os.path.exists(can_file): + print("Skip: " + img_file) + continue + + print(img_file) + + img = cv2.imread(img_file) + + # random threshold + # while True: + # threshold1 = random.randint(0, 127) + # threshold2 = random.randint(128, 255) + # 
if threshold2 - threshold1 > 80: + # break + + # fixed threshold + threshold1 = 100 + threshold2 = 200 + + img = cv2.Canny(img, threshold1, threshold2) + + cv2.imwrite(can_file, img) +``` + +### キャプションファイルの作成 + +学習用画像のbasenameと同じ名前で、それぞれの画像に対応したキャプションファイルを作成してください。生成時のプロンプトをそのまま利用すれば良いと思われます。 + +`sdxl_gen_img.py` で生成した場合は、画像内のメタデータに生成時のプロンプトが記録されていますので、以下のようなスクリプトで学習用画像と同じディレクトリにキャプションファイルを作成できます(拡張子 `.txt`)。 + +```python +import glob +import os +from PIL import Image + +IMAGES_DIR = "path/to/generated/images" + +img_files = glob.glob(IMAGES_DIR + "/*.png") +for img_file in img_files: + cap_file = img_file.replace(".png", ".txt") + if os.path.exists(cap_file): + print(f"Skip: {img_file}") + continue + print(img_file) + + img = Image.open(img_file) + prompt = img.text["prompt"] if "prompt" in img.text else "" + if prompt == "": + print(f"Prompt not found in {img_file}") + + with open(cap_file, "w") as f: + f.write(prompt + "\n") +``` + +### データセットの設定ファイルの作成 + +コマンドラインオプションからの指定も可能ですが、`.toml`ファイルを作成する場合は `conditioning_data_dir` に加工した画像を保存したディレクトリを指定します。 + +以下は設定ファイルの例です。 + +```toml +[general] +flip_aug = false +color_aug = false +resolution = [1024,1024] + +[[datasets]] +batch_size = 8 +enable_bucket = false + + [[datasets.subsets]] + image_dir = "path/to/generated/image/dir" + caption_extension = ".txt" + conditioning_data_dir = "path/to/canny/image/dir" +``` + +## 謝辞 + +ControlNetの作者である lllyasviel 氏、実装上のアドバイスとトラブル解決へのご尽力をいただいた furusu 氏、ControlNetデータセットを実装していただいた ddPn08 氏に感謝いたします。 + +## サンプル +Canny +![kohya_ss_girl_standing_at_classroom_smiling_to_the_viewer_class_78976b3e-0d4d-4ea0-b8e3-053ae493abbc](https://github.com/kohya-ss/sd-scripts/assets/52813779/37e9a736-649b-4c0f-ab26-880a1bf319b5) + +![im_20230820104253_000_1](https://github.com/kohya-ss/sd-scripts/assets/52813779/c8896900-ab86-4120-932f-6e2ae17b77c0) + +![im_20230820104302_000_1](https://github.com/kohya-ss/sd-scripts/assets/52813779/b12457a0-ee3c-450e-ba9a-b712d0fe86bb) + +![im_20230820104310_000_1](https://github.com/kohya-ss/sd-scripts/assets/52813779/8845b8d9-804a-44ac-9618-113a28eac8a1) + diff --git a/docs/train_lllite_README.md b/docs/train_lllite_README.md new file mode 100644 index 000000000..04dc12da2 --- /dev/null +++ b/docs/train_lllite_README.md @@ -0,0 +1,217 @@ +# About ControlNet-LLLite + +__This is an extremely experimental implementation and may change significantly in the future.__ + +日本語版は[こちら](./train_lllite_README-ja.md) + +## Overview + +ControlNet-LLLite is a lightweight version of [ControlNet](https://github.com/lllyasviel/ControlNet). It is a "LoRA Like Lite" that is inspired by LoRA and has a lightweight structure. Currently, only SDXL is supported. + +## Sample weight file and inference + +Sample weight file is available here: https://huggingface.co/kohya-ss/controlnet-lllite + +A custom node for ComfyUI is available: https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI + +Sample images are at the end of this page. + +## Model structure + +A single LLLite module consists of a conditioning image embedding that maps a conditioning image to a latent space and a small network with a structure similar to LoRA. The LLLite module is added to U-Net's Linear and Conv in the same way as LoRA. Please refer to the source code for details. + +Due to the limitations of the inference environment, only CrossAttention (attn1 q/k/v, attn2 q) is currently added. 
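+
+For intuition, a rough sketch of one such module is shown below. This is only an illustrative approximation, not the actual implementation in this repository; the class name, the signature, and the way the conditioning embedding is injected are assumptions made for this example, so please refer to the source code for the real structure.
+
+```python
+import torch
+import torch.nn as nn
+
+
+class LLLiteModuleSketch(nn.Module):
+    """Hypothetical sketch: wrap one U-Net Linear with a conditioning-dependent, LoRA-like branch."""
+
+    def __init__(self, org_linear: nn.Linear, cond_emb_dim: int = 32, rank: int = 64):
+        super().__init__()
+        self.org_linear = org_linear  # the original U-Net Linear being wrapped
+        in_dim, out_dim = org_linear.in_features, org_linear.out_features
+        # LoRA-like branch: (input + conditioning embedding) -> rank -> output delta
+        self.down = nn.Linear(in_dim + cond_emb_dim, rank, bias=False)
+        self.up = nn.Linear(rank, out_dim, bias=False)
+        nn.init.zeros_(self.up.weight)  # the branch starts as a no-op, as in LoRA
+
+    def forward(self, x: torch.Tensor, cond_emb: torch.Tensor) -> torch.Tensor:
+        # x: (batch, tokens, in_dim); cond_emb: (batch, tokens, cond_emb_dim), i.e. the
+        # conditioning image already embedded and aligned to the token sequence
+        h = torch.cat([x, cond_emb], dim=-1)
+        return self.org_linear(x) + self.up(self.down(h))
+```
+
+In this picture, the `--cond_emb_dim` and `--network_dim` options described in the training section below correspond to `cond_emb_dim` and `rank`, respectively.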
+ +## Model training + +### Preparing the dataset + +In addition to the normal dataset, please store the conditioning image in the directory specified by `conditioning_data_dir`. The conditioning image must have the same basename as the training image. The conditioning image will be automatically resized to the same size as the training image. The conditioning image does not require a caption file. + +```toml +[[datasets.subsets]] +image_dir = "path/to/image/dir" +caption_extension = ".txt" +conditioning_data_dir = "path/to/conditioning/image/dir" +``` + +At the moment, random_crop cannot be used. + +For training data, it is easiest to use a synthetic dataset with the original model-generated images as training images and processed images as conditioning images (the quality of the dataset may be problematic). See below for specific methods of synthesizing datasets. + +Note that if you use an image with a different art style than the original model as a training image, the model will have to learn not only the control but also the art style. ControlNet-LLLite has a small capacity, so it is not suitable for learning art styles. In such cases, increase the number of dimensions as described below. + +### Training + +Run `sdxl_train_control_net_lllite.py`. You can specify the dimension of the conditioning image embedding with `--cond_emb_dim`. You can specify the rank of the LoRA-like module with `--network_dim`. Other options are the same as `sdxl_train_network.py`, but `--network_module` is not required. + +Since a large amount of memory is used during training, please enable memory-saving options such as cache and gradient checkpointing. It is also effective to use BFloat16 with the `--full_bf16` option (requires RTX 30 series or later GPU). It has been confirmed to work with 24GB VRAM. + +For the sample Canny, the dimension of the conditioning image embedding is 32. The rank of the LoRA-like module is also 64. Adjust according to the features of the conditioning image you are targeting. + +(The sample Canny is probably quite difficult. It may be better to reduce it to about half for depth, etc.) + +The following is an example of a .toml configuration. + +```toml +pretrained_model_name_or_path = "/path/to/model_trained_on.safetensors" +max_train_epochs = 12 +max_data_loader_n_workers = 4 +persistent_data_loader_workers = true +seed = 42 +gradient_checkpointing = true +mixed_precision = "bf16" +save_precision = "bf16" +full_bf16 = true +optimizer_type = "adamw8bit" +learning_rate = 2e-4 +xformers = true +output_dir = "/path/to/output/dir" +output_name = "output_name" +save_every_n_epochs = 1 +save_model_as = "safetensors" +vae_batch_size = 4 +cache_latents = true +cache_latents_to_disk = true +cache_text_encoder_outputs = true +cache_text_encoder_outputs_to_disk = true +network_dim = 64 +cond_emb_dim = 32 +dataset_config = "/path/to/dataset.toml" +``` + +### Inference + +If you want to generate images with a script, run `sdxl_gen_img.py`. You can specify the LLLite model file with `--control_net_lllite_models`. The dimension is automatically obtained from the model file. + +Specify the conditioning image to be used for inference with `--guide_image_path`. Since preprocess is not performed, if it is Canny, specify an image processed with Canny (white line on black background). `--control_net_preps`, `--control_net_weights`, and `--control_net_ratios` are not supported. + +## How to synthesize a dataset + +### Generating training images + +Generate images with the base model for training. 
Please generate them with Web UI or ComfyUI etc. The image size should be the default size of the model (1024x1024, etc.). You can also use bucketing. In that case, please generate it at an arbitrary resolution. + +The captions and other settings when generating the images should be the same as when generating the images with the trained ControlNet-LLLite model. + +Save the generated images in an arbitrary directory. Specify this directory in the dataset configuration file. + + +You can also generate them with `sdxl_gen_img.py` in this repository. For example, run as follows: + +```dos +python sdxl_gen_img.py --ckpt path/to/model.safetensors --n_iter 1 --scale 10 --steps 36 --outdir path/to/output/dir --xformers --W 1024 --H 1024 --original_width 2048 --original_height 2048 --bf16 --sampler ddim --batch_size 4 --vae_batch_size 2 --images_per_prompt 512 --max_embeddings_multiples 1 --prompt "{portrait|digital art|anime screen cap|detailed illustration} of 1girl, {standing|sitting|walking|running|dancing} on {classroom|street|town|beach|indoors|outdoors}, {looking at viewer|looking away|looking at another}, {in|wearing} {shirt and skirt|school uniform|casual wear} { |, dynamic pose}, (solo), teen age, {0-1$$smile,|blush,|kind smile,|expression less,|happy,|sadness,} {0-1$$upper body,|full body,|cowboy shot,|face focus,} trending on pixiv, {0-2$$depth of fields,|8k wallpaper,|highly detailed,|pov,} {0-1$$summer, |winter, |spring, |autumn, } beautiful face { |, from below|, from above|, from side|, from behind|, from back} --n nsfw, bad face, lowres, low quality, worst quality, low effort, watermark, signature, ugly, poorly drawn" +``` + +This is a setting for VRAM 24GB. Adjust `--batch_size` and `--vae_batch_size` according to the VRAM size. + +The images are generated randomly using wildcards in `--prompt`. Adjust as necessary. + +### Processing images + +Use an external program to process the generated images. Save the processed images in an arbitrary directory. These will be the conditioning images. + +For example, you can use the following script to process the images with Canny. + +```python +import glob +import os +import random +import cv2 +import numpy as np + +IMAGES_DIR = "path/to/generated/images" +CANNY_DIR = "path/to/canny/images" + +os.makedirs(CANNY_DIR, exist_ok=True) +img_files = glob.glob(IMAGES_DIR + "/*.png") +for img_file in img_files: + can_file = CANNY_DIR + "/" + os.path.basename(img_file) + if os.path.exists(can_file): + print("Skip: " + img_file) + continue + + print(img_file) + + img = cv2.imread(img_file) + + # random threshold + # while True: + # threshold1 = random.randint(0, 127) + # threshold2 = random.randint(128, 255) + # if threshold2 - threshold1 > 80: + # break + + # fixed threshold + threshold1 = 100 + threshold2 = 200 + + img = cv2.Canny(img, threshold1, threshold2) + + cv2.imwrite(can_file, img) +``` + +### Creating caption files + +Create a caption file for each image with the same basename as the training image. It is fine to use the same caption as the one used when generating the image. + +If you generated the images with `sdxl_gen_img.py`, you can use the following script to create the caption files (`*.txt`) from the metadata in the generated images. 
+ +```python +import glob +import os +from PIL import Image + +IMAGES_DIR = "path/to/generated/images" + +img_files = glob.glob(IMAGES_DIR + "/*.png") +for img_file in img_files: + cap_file = img_file.replace(".png", ".txt") + if os.path.exists(cap_file): + print(f"Skip: {img_file}") + continue + print(img_file) + + img = Image.open(img_file) + prompt = img.text["prompt"] if "prompt" in img.text else "" + if prompt == "": + print(f"Prompt not found in {img_file}") + + with open(cap_file, "w") as f: + f.write(prompt + "\n") +``` + +### Creating a dataset configuration file + +You can use the command line arguments of `sdxl_train_control_net_lllite.py` to specify the conditioning image directory. However, if you want to use a `.toml` file, specify the conditioning image directory in `conditioning_data_dir`. + +```toml +[general] +flip_aug = false +color_aug = false +resolution = [1024,1024] + +[[datasets]] +batch_size = 8 +enable_bucket = false + + [[datasets.subsets]] + image_dir = "path/to/generated/image/dir" + caption_extension = ".txt" + conditioning_data_dir = "path/to/canny/image/dir" +``` + +## Credit + +I would like to thank lllyasviel, the author of ControlNet, furusu, who provided me with advice on implementation and helped me solve problems, and ddPn08, who implemented the ControlNet dataset. + +## Sample + +Canny +![kohya_ss_girl_standing_at_classroom_smiling_to_the_viewer_class_78976b3e-0d4d-4ea0-b8e3-053ae493abbc](https://github.com/kohya-ss/sd-scripts/assets/52813779/37e9a736-649b-4c0f-ab26-880a1bf319b5) + +![im_20230820104253_000_1](https://github.com/kohya-ss/sd-scripts/assets/52813779/c8896900-ab86-4120-932f-6e2ae17b77c0) + +![im_20230820104302_000_1](https://github.com/kohya-ss/sd-scripts/assets/52813779/b12457a0-ee3c-450e-ba9a-b712d0fe86bb) + +![im_20230820104310_000_1](https://github.com/kohya-ss/sd-scripts/assets/52813779/8845b8d9-804a-44ac-9618-113a28eac8a1) diff --git a/docs/train_network_README-ja.md b/docs/train_network_README-ja.md index e620a8642..2205a7736 100644 --- a/docs/train_network_README-ja.md +++ b/docs/train_network_README-ja.md @@ -181,6 +181,8 @@ python networks\extract_lora_from_dylora.py --model "foldername/dylora-model.saf 詳細は[PR #355](https://github.com/kohya-ss/sd-scripts/pull/355) をご覧ください。 +SDXLは現在サポートしていません。 + フルモデルの25個のブロックの重みを指定できます。最初のブロックに該当するLoRAは存在しませんが、階層別LoRA適用等との互換性のために25個としています。またconv2d3x3に拡張しない場合も一部のブロックにはLoRAが存在しませんが、記述を統一するため常に25個の値を指定してください。 `--network_args` で以下の引数を指定してください。 @@ -246,6 +248,8 @@ network_args = [ "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8, merge_lora.pyでStable DiffusionのモデルにLoRAの学習結果をマージしたり、複数のLoRAモデルをマージしたりできます。 +SDXL向けにはsdxl_merge_lora.pyを用意しています。オプション等は同一ですので、以下のmerge_lora.pyを読み替えてください。 + ### Stable DiffusionのモデルにLoRAのモデルをマージする マージ後のモデルは通常のStable Diffusionのckptと同様に扱えます。たとえば以下のようなコマンドラインになります。 @@ -276,29 +280,29 @@ python networks\merge_lora.py --sd_model ..\model\model.ckpt ### 複数のLoRAのモデルをマージする -__複数のLoRAをマージする場合は原則として `svd_merge_lora.py` を使用してください。__ 単純なup同士やdown同士のマージでは、計算結果が正しくなくなるためです。 - -`merge_lora.py` によるマージは差分抽出法でLoRAを生成する場合等、ごく限られた場合でのみ有効です。 +--concatオプションを指定すると、複数のLoRAを単純に結合して新しいLoRAモデルを作成できます。ファイルサイズ(およびdim/rank)は指定したLoRAの合計サイズになります(マージ時にdim (rank)を変更する場合は `svd_merge_lora.py` を使用してください)。 たとえば以下のようなコマンドラインになります。 ``` -python networks\merge_lora.py +python networks\merge_lora.py --save_precision bf16 --save_to ..\lora_train1\model-char1-style1-merged.safetensors - --models ..\lora_train1\last.safetensors ..\lora_train2\last.safetensors --ratios 0.6 0.4 + --models 
..\lora_train1\last.safetensors ..\lora_train2\last.safetensors + --ratios 1.0 -1.0 --concat --shuffle ``` ---sd_modelオプションは指定不要です。 +--concatオプションを指定します。 + +また--shuffleオプションを追加し、重みをシャッフルします。シャッフルしないとマージ後のLoRAから元のLoRAを取り出せるため、コピー機学習などの場合には学習元データが明らかになります。ご注意ください。 --save_toオプションにマージ後のLoRAモデルの保存先を指定します(.ckptまたは.safetensors、拡張子で自動判定)。 --modelsに学習したLoRAのモデルファイルを指定します。三つ以上も指定可能です。 ---ratiosにそれぞれのモデルの比率(どのくらい重みを元モデルに反映するか)を0~1.0の数値で指定します。二つのモデルを一対一でマージす場合は、「0.5 0.5」になります。「1.0 1.0」では合計の重みが大きくなりすぎて、恐らく結果はあまり望ましくないものになると思われます。 +--ratiosにそれぞれのモデルの比率(どのくらい重みを元モデルに反映するか)を0~1.0の数値で指定します。二つのモデルを一対一でマージする場合は、「0.5 0.5」になります。「1.0 1.0」では合計の重みが大きくなりすぎて、恐らく結果はあまり望ましくないものになると思われます。 v1で学習したLoRAとv2で学習したLoRA、rank(次元数)の異なるLoRAはマージできません。U-NetだけのLoRAとU-Net+Text EncoderのLoRAはマージできるはずですが、結果は未知数です。 - ### その他のオプション * precision @@ -306,6 +310,7 @@ v1で学習したLoRAとv2で学習したLoRA、rank(次元数)の異なるL * save_precision * モデル保存時の精度をfloat、fp16、bf16から指定できます。省略時はprecisionと同じ精度になります。 +他にもいくつかのオプションがありますので、--helpで確認してください。 ## 複数のrankが異なるLoRAのモデルをマージする diff --git a/fine_tune.py b/fine_tune.py index fbb9e54c4..f300d4688 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -5,13 +5,19 @@ import gc import math import os -import toml from multiprocessing import Value +import toml from tqdm import tqdm import torch +try: + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + from library.ipex import ipex_init + ipex_init() +except Exception: + pass from accelerate.utils import set_seed -import diffusers from diffusers import DDPMScheduler import library.train_util as train_util @@ -25,8 +31,6 @@ apply_snr_weight, get_weighted_text_embeddings, prepare_scheduler_for_custom_training, - pyramid_noise_like, - apply_noise_offset, scale_v_prediction_loss_like_noise_prediction, ) @@ -44,7 +48,7 @@ def train(args): # データセットを準備する if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True)) + blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, False, True)) if args.dataset_config is not None: print(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) @@ -95,7 +99,7 @@ def train(args): # acceleratorを準備する print("prepare accelerator") - accelerator, unwrap_model = train_util.prepare_accelerator(args) + accelerator = train_util.prepare_accelerator(args) # mixed precisionに対応した型を用意しておき適宜castする weight_dtype, save_dtype = train_util.prepare_dtype(args) @@ -139,13 +143,13 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): # モデルに xformers とか memory efficient attention を組み込む if args.diffusers_xformers: - print("Use xformers by Diffusers") + accelerator.print("Use xformers by Diffusers") set_diffusers_xformers_flag(unet, True) else: # Windows版のxformersはfloatで学習できないのでxformersを使わない設定も可能にしておく必要がある - print("Disable Diffusers' xformers") + accelerator.print("Disable Diffusers' xformers") set_diffusers_xformers_flag(unet, False) - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) # 学習を準備する if cache_latents: @@ -168,7 +172,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): training_models.append(unet) if args.train_text_encoder: - print("enable text encoder training") + accelerator.print("enable text encoder training") if args.gradient_checkpointing: text_encoder.gradient_checkpointing_enable() training_models.append(text_encoder) @@ -194,7 +198,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): 
params_to_optimize = params # 学習に必要なクラスを準備する - print("prepare optimizer, data loader etc.") + accelerator.print("prepare optimizer, data loader etc.") _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) # dataloaderを準備する @@ -214,7 +218,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): args.max_train_steps = args.max_train_epochs * math.ceil( len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) - print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) @@ -227,7 +231,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): assert ( args.mixed_precision == "fp16" ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" - print("enable full fp16 training.") + accelerator.print("enable full fp16 training.") unet.to(weight_dtype) text_encoder.to(weight_dtype) @@ -257,14 +261,16 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): # 学習する total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - print("running training / 学習開始") - print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") - print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") - print(f" num epochs / epoch数: {num_train_epochs}") - print(f" batch size per device / バッチサイズ: {args.train_batch_size}") - print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") - print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") - print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + accelerator.print("running training / 学習開始") + accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print(f" batch size per device / バッチサイズ: {args.train_batch_size}") + accelerator.print( + f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" + ) + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") global_step = 0 @@ -273,12 +279,17 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False ) prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) + if args.zero_terminal_snr: + custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) if accelerator.is_main_process: - accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name) + init_kwargs = {} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) for epoch in 
range(num_train_epochs): - print(f"\nepoch {epoch+1}/{num_train_epochs}") + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 for m in training_models: @@ -314,20 +325,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype ) - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents, device=latents.device) - if args.noise_offset: - noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale) - elif args.multires_noise_iterations: - noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount) - - # Sample a random timestep for each image - timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) - timesteps = timesteps.long() - - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) # Predict the noise residual with accelerator.autocast(): @@ -389,15 +389,17 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): epoch, num_train_epochs, global_step, - unwrap_model(text_encoder), - unwrap_model(unet), + accelerator.unwrap_model(text_encoder), + accelerator.unwrap_model(unet), vae, ) current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず if args.logging_dir is not None: logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} - if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): # tracking d*lr value + if ( + args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower() + ): # tracking d*lr value logs["lr/d*lr"] = ( lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] ) @@ -432,8 +434,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): epoch, num_train_epochs, global_step, - unwrap_model(text_encoder), - unwrap_model(unet), + accelerator.unwrap_model(text_encoder), + accelerator.unwrap_model(unet), vae, ) @@ -441,8 +443,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): is_main_process = accelerator.is_main_process if is_main_process: - unet = unwrap_model(unet) - text_encoder = unwrap_model(text_encoder) + unet = accelerator.unwrap_model(unet) + text_encoder = accelerator.unwrap_model(text_encoder) accelerator.end_training() diff --git a/finetune/prepare_buckets_latents.py b/finetune/prepare_buckets_latents.py index fd289d1d3..af08c5375 100644 --- a/finetune/prepare_buckets_latents.py +++ b/finetune/prepare_buckets_latents.py @@ -34,16 +34,7 @@ def collate_fn_remove_corrupted(batch): return batch -def get_latents(vae, images, weight_dtype): - img_tensors = [IMAGE_TRANSFORMS(image) for image in images] - img_tensors = torch.stack(img_tensors) - img_tensors = img_tensors.to(DEVICE, weight_dtype) - with torch.no_grad(): - latents = vae.encode(img_tensors).latent_dist.sample().float().to("cpu").numpy() - return latents - - -def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip, recursive): +def 
get_npz_filename(data_dir, image_key, is_full_path, recursive): if is_full_path: base_name = os.path.splitext(os.path.basename(image_key))[0] relative_path = os.path.relpath(os.path.dirname(image_key), data_dir) @@ -51,19 +42,20 @@ def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip, recursive): base_name = image_key relative_path = "" - if flip: - base_name += "_flip" - if recursive and relative_path: - return os.path.join(data_dir, relative_path, base_name) + return os.path.join(data_dir, relative_path, base_name) + ".npz" else: - return os.path.join(data_dir, base_name) + return os.path.join(data_dir, base_name) + ".npz" def main(args): # assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります" if args.bucket_reso_steps % 8 > 0: print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります") + if args.bucket_reso_steps % 32 > 0: + print( + f"WARNING: bucket_reso_steps is not divisible by 32. It is not working with SDXL / bucket_reso_stepsが32で割り切れません。SDXLでは動作しません" + ) train_data_dir_path = Path(args.train_data_dir) image_paths: List[str] = [str(p) for p in train_util.glob_images_pathlib(train_data_dir_path, args.recursive)] @@ -107,34 +99,7 @@ def main(args): def process_batch(is_last): for bucket in bucket_manager.buckets: if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size: - latents = get_latents(vae, [img for _, img in bucket], weight_dtype) - assert ( - latents.shape[2] == bucket[0][1].shape[0] // 8 and latents.shape[3] == bucket[0][1].shape[1] // 8 - ), f"latent shape {latents.shape}, {bucket[0][1].shape}" - - for (image_key, _), latent in zip(bucket, latents): - npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive) - np.savez(npz_file_name, latent) - - # flip - if args.flip_aug: - latents = get_latents(vae, [img[:, ::-1].copy() for _, img in bucket], weight_dtype) # copyがないとTensor変換できない - - for (image_key, _), latent in zip(bucket, latents): - npz_file_name = get_npz_filename_wo_ext( - args.train_data_dir, image_key, args.full_path, True, args.recursive - ) - np.savez(npz_file_name, latent) - else: - # remove existing flipped npz - for image_key, _ in bucket: - npz_file_name = ( - get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz" - ) - if os.path.isfile(npz_file_name): - print(f"remove existing flipped npz / 既存のflipされたnpzファイルを削除します: {npz_file_name}") - os.remove(npz_file_name) - + train_util.cache_batch_latents(vae, True, bucket, args.flip_aug, False) bucket.clear() # 読み込みの高速化のためにDataLoaderを使うオプション @@ -194,50 +159,19 @@ def process_batch(is_last): resized_size[0] >= reso[0] and resized_size[1] >= reso[1] ), f"internal error resized size is small: {resized_size}, {reso}" - # 既に存在するファイルがあればshapeを確認して同じならskipする + # 既に存在するファイルがあればshape等を確認して同じならskipする + npz_file_name = get_npz_filename(args.train_data_dir, image_key, args.full_path, args.recursive) if args.skip_existing: - npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive) + ".npz"] - if args.flip_aug: - npz_files.append( - get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz" - ) - - found = True - for npz_file in npz_files: - if not os.path.exists(npz_file): - found = False - break - - dat = np.load(npz_file)["arr_0"] - if dat.shape[1] != reso[1] // 8 or dat.shape[2] != reso[0] // 8: # 
latentsのshapeを確認 - found = False - break - if found: + if train_util.is_disk_cached_latents_is_expected(reso, npz_file_name, args.flip_aug): continue - # 画像をリサイズしてトリミングする - # PILにinter_areaがないのでcv2で…… - image = np.array(image) - if resized_size[0] != image.shape[1] or resized_size[1] != image.shape[0]: # リサイズ処理が必要? - image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) - - if resized_size[0] > reso[0]: - trim_size = resized_size[0] - reso[0] - image = image[:, trim_size // 2 : trim_size // 2 + reso[0]] - - if resized_size[1] > reso[1]: - trim_size = resized_size[1] - reso[1] - image = image[trim_size // 2 : trim_size // 2 + reso[1]] - - assert ( - image.shape[0] == reso[1] and image.shape[1] == reso[0] - ), f"internal error, illegal trimmed size: {image.shape}, {reso}" - - # # debug - # cv2.imwrite(f"r:\\test\\img_{len(img_ar_errors)}.jpg", image[:, :, ::-1]) - # バッチへ追加 - bucket_manager.add_image(reso, (image_key, image)) + image_info = train_util.ImageInfo(image_key, 1, "", False, image_path) + image_info.latents_npz = npz_file_name + image_info.bucket_reso = reso + image_info.resized_size = resized_size + image_info.image = image + bucket_manager.add_image(reso, image_info) # バッチを推論するか判定して推論する process_batch(False) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 9ac5cd177..70ca67942 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -65,6 +65,13 @@ import diffusers import numpy as np import torch +try: + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + from library.ipex import ipex_init + ipex_init() +except Exception: + pass import torchvision from diffusers import ( AutoencoderKL, @@ -79,11 +86,10 @@ HeunDiscreteScheduler, KDPM2DiscreteScheduler, KDPM2AncestralDiscreteScheduler, - UNet2DConditionModel, + # UNet2DConditionModel, StableDiffusionPipeline, ) from einops import rearrange -from torch import einsum from tqdm import tqdm from torchvision import transforms from transformers import CLIPTextModel, CLIPTokenizer, CLIPModel, CLIPTextConfig @@ -96,15 +102,11 @@ from networks.lora import LoRANetwork import tools.original_control_net as original_control_net from tools.original_control_net import ControlNetInfo +from library.original_unet import UNet2DConditionModel +from library.original_unet import FlashAttentionFunction from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI -# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う -TOKENIZER_PATH = "openai/clip-vit-large-patch14" -V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う - -DEFAULT_TOKEN_LENGTH = 75 - # scheduler: SCHEDULER_LINEAR_START = 0.00085 SCHEDULER_LINEAR_END = 0.0120 @@ -136,339 +138,153 @@ 高速化のためのモジュール入れ替え """ -# FlashAttentionを使うCrossAttention -# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py -# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE - -# constants - -EPSILON = 1e-6 - -# helper functions - - -def exists(val): - return val is not None - - -def default(val, d): - return val if exists(val) else d - - -# flash attention forwards and backwards - -# https://arxiv.org/abs/2205.14135 - - -class FlashAttentionFunction(torch.autograd.Function): - @staticmethod - @torch.no_grad() - def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): - """Algorithm 2 in the paper""" - - device = q.device - dtype = q.dtype - max_neg_value = 
-torch.finfo(q.dtype).max - qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) - - o = torch.zeros_like(q) - all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) - all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device) - - scale = q.shape[-1] ** -0.5 - - if not exists(mask): - mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) - else: - mask = rearrange(mask, "b n -> b 1 1 n") - mask = mask.split(q_bucket_size, dim=-1) - - row_splits = zip( - q.split(q_bucket_size, dim=-2), - o.split(q_bucket_size, dim=-2), - mask, - all_row_sums.split(q_bucket_size, dim=-2), - all_row_maxes.split(q_bucket_size, dim=-2), - ) - - for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): - q_start_index = ind * q_bucket_size - qk_len_diff - - col_splits = zip( - k.split(k_bucket_size, dim=-2), - v.split(k_bucket_size, dim=-2), - ) - - for k_ind, (kc, vc) in enumerate(col_splits): - k_start_index = k_ind * k_bucket_size - - attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale - - if exists(row_mask): - attn_weights.masked_fill_(~row_mask, max_neg_value) - - if causal and q_start_index < (k_start_index + k_bucket_size - 1): - causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( - q_start_index - k_start_index + 1 - ) - attn_weights.masked_fill_(causal_mask, max_neg_value) - - block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) - attn_weights -= block_row_maxes - exp_weights = torch.exp(attn_weights) - - if exists(row_mask): - exp_weights.masked_fill_(~row_mask, 0.0) - - block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON) - new_row_maxes = torch.maximum(block_row_maxes, row_maxes) - - exp_values = einsum("... i j, ... j d -> ... i d", exp_weights, vc) - - exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) - exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) - - new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums - - oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values) - - row_maxes.copy_(new_row_maxes) - row_sums.copy_(new_row_sums) - - ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) - ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) - - return o - - @staticmethod - @torch.no_grad() - def backward(ctx, do): - """Algorithm 4 in the paper""" - - causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args - q, k, v, o, l, m = ctx.saved_tensors - - device = q.device - - max_neg_value = -torch.finfo(q.dtype).max - qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) - - dq = torch.zeros_like(q) - dk = torch.zeros_like(k) - dv = torch.zeros_like(v) - - row_splits = zip( - q.split(q_bucket_size, dim=-2), - o.split(q_bucket_size, dim=-2), - do.split(q_bucket_size, dim=-2), - mask, - l.split(q_bucket_size, dim=-2), - m.split(q_bucket_size, dim=-2), - dq.split(q_bucket_size, dim=-2), - ) - - for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): - q_start_index = ind * q_bucket_size - qk_len_diff - - col_splits = zip( - k.split(k_bucket_size, dim=-2), - v.split(k_bucket_size, dim=-2), - dk.split(k_bucket_size, dim=-2), - dv.split(k_bucket_size, dim=-2), - ) - - for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): - k_start_index = k_ind * k_bucket_size - - attn_weights = einsum("... i d, ... j d -> ... 
i j", qc, kc) * scale - - if causal and q_start_index < (k_start_index + k_bucket_size - 1): - causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( - q_start_index - k_start_index + 1 - ) - attn_weights.masked_fill_(causal_mask, max_neg_value) - - exp_attn_weights = torch.exp(attn_weights - mc) - - if exists(row_mask): - exp_attn_weights.masked_fill_(~row_mask, 0.0) - - p = exp_attn_weights / lc - - dv_chunk = einsum("... i j, ... i d -> ... j d", p, doc) - dp = einsum("... i d, ... j d -> ... i j", doc, vc) - - D = (doc * oc).sum(dim=-1, keepdims=True) - ds = p * scale * (dp - D) - - dq_chunk = einsum("... i j, ... j d -> ... i d", ds, kc) - dk_chunk = einsum("... i j, ... i d -> ... j d", ds, qc) +def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa): + if mem_eff_attn: + print("Enable memory efficient attention for U-Net") - dqc.add_(dq_chunk) - dkc.add_(dk_chunk) - dvc.add_(dv_chunk) + # これはDiffusersのU-Netではなく自前のU-Netなので置き換えなくても良い + unet.set_use_memory_efficient_attention(False, True) + elif xformers: + print("Enable xformers for U-Net") + try: + import xformers.ops + except ImportError: + raise ImportError("No xformers / xformersがインストールされていないようです") - return dq, dk, dv, None, None, None, None + unet.set_use_memory_efficient_attention(True, False) + elif sdpa: + print("Enable SDPA for U-Net") + unet.set_use_memory_efficient_attention(False, False) + unet.set_use_sdpa(True) # TODO common train_util.py -def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers): +def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers, sdpa): if mem_eff_attn: - replace_unet_cross_attn_to_memory_efficient() + replace_vae_attn_to_memory_efficient() elif xformers: - replace_unet_cross_attn_to_xformers() + replace_vae_attn_to_xformers() + elif sdpa: + replace_vae_attn_to_sdpa() -def replace_unet_cross_attn_to_memory_efficient(): - print("CrossAttention.forward has been replaced to FlashAttention (not xformers) and NAI style Hypernetwork") +def replace_vae_attn_to_memory_efficient(): + print("VAE Attention.forward has been replaced to FlashAttention (not xformers)") flash_func = FlashAttentionFunction - def forward_flash_attn(self, x, context=None, mask=None): + def forward_flash_attn(self, hidden_states, **kwargs): q_bucket_size = 512 k_bucket_size = 1024 - h = self.heads - q = self.to_q(x) + residual = hidden_states + batch, channel, height, width = hidden_states.shape - context = context if context is not None else x - context = context.to(x.dtype) + # norm + hidden_states = self.group_norm(hidden_states) - if hasattr(self, "hypernetwork") and self.hypernetwork is not None: - context_k, context_v = self.hypernetwork.forward(x, context) - context_k = context_k.to(x.dtype) - context_v = context_v.to(x.dtype) - else: - context_k = context - context_v = context + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) - k = self.to_k(context_k) - v = self.to_v(context_v) - del context, x + # proj to q, k, v + query_proj = self.to_q(hidden_states) + key_proj = self.to_k(hidden_states) + value_proj = self.to_v(hidden_states) - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) + query_proj, key_proj, value_proj = map( + lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (query_proj, key_proj, value_proj) + ) - out = flash_func.apply(q, k, v, mask, False, 
q_bucket_size, k_bucket_size) + out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size) out = rearrange(out, "b h n d -> b n (h d)") - # diffusers 0.7.0~ - out = self.to_out[0](out) - out = self.to_out[1](out) - return out - - diffusers.models.attention.CrossAttention.forward = forward_flash_attn - + # compute next hidden_states + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) -def replace_unet_cross_attn_to_xformers(): - print("CrossAttention.forward has been replaced to enable xformers and NAI style Hypernetwork") - try: - import xformers.ops - except ImportError: - raise ImportError("No xformers / xformersがインストールされていないようです") + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) - def forward_xformers(self, x, context=None, mask=None): - h = self.heads - q_in = self.to_q(x) + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states - context = default(context, x) - context = context.to(x.dtype) + def forward_flash_attn_0_14(self, hidden_states, **kwargs): + if not hasattr(self, "to_q"): + self.to_q = self.query + self.to_k = self.key + self.to_v = self.value + self.to_out = [self.proj_attn, torch.nn.Identity()] + self.heads = self.num_heads + return forward_flash_attn(self, hidden_states, **kwargs) + + if diffusers.__version__ < "0.15.0": + diffusers.models.attention.AttentionBlock.forward = forward_flash_attn_0_14 + else: + diffusers.models.attention_processor.Attention.forward = forward_flash_attn - if hasattr(self, "hypernetwork") and self.hypernetwork is not None: - context_k, context_v = self.hypernetwork.forward(x, context) - context_k = context_k.to(x.dtype) - context_v = context_v.to(x.dtype) - else: - context_k = context - context_v = context - k_in = self.to_k(context_k) - v_in = self.to_v(context_v) +def replace_vae_attn_to_xformers(): + print("VAE: Attention.forward has been replaced to xformers") + import xformers.ops - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in)) - del q_in, k_in, v_in + def forward_xformers(self, hidden_states, **kwargs): + residual = hidden_states + batch, channel, height, width = hidden_states.shape - q = q.contiguous() - k = k.contiguous() - v = v.contiguous() - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる + # norm + hidden_states = self.group_norm(hidden_states) - out = rearrange(out, "b n h d -> b n (h d)", h=h) + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) - # diffusers 0.7.0~ - out = self.to_out[0](out) - out = self.to_out[1](out) - return out + # proj to q, k, v + query_proj = self.to_q(hidden_states) + key_proj = self.to_k(hidden_states) + value_proj = self.to_v(hidden_states) - diffusers.models.attention.CrossAttention.forward = forward_xformers + query_proj, key_proj, value_proj = map( + lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (query_proj, key_proj, value_proj) + ) + query_proj = query_proj.contiguous() + key_proj = key_proj.contiguous() + value_proj = value_proj.contiguous() + out = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None) -def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers): - if mem_eff_attn: - replace_vae_attn_to_memory_efficient() - elif xformers: - # とりあえずDiffusersのxformersを使う。AttentionがあるのはMidBlockのみ - 
print("Use Diffusers xformers for VAE") - vae.set_use_memory_efficient_attention_xformers(True) + out = rearrange(out, "b h n d -> b n (h d)") - """ - # VAEがbfloat16でメモリ消費が大きい問題を解決する - upsamplers = [] - for block in vae.decoder.up_blocks: - if block.upsamplers is not None: - upsamplers.extend(block.upsamplers) - - def forward_upsample(_self, hidden_states, output_size=None): - assert hidden_states.shape[1] == _self.channels - if _self.use_conv_transpose: - return _self.conv(hidden_states) - - dtype = hidden_states.dtype - if dtype == torch.bfloat16: - assert output_size is None - # repeat_interleaveはすごく遅いが、回数はあまり呼ばれないので許容する - hidden_states = hidden_states.repeat_interleave(2, dim=-1) - hidden_states = hidden_states.repeat_interleave(2, dim=-2) - else: - if hidden_states.shape[0] >= 64: - hidden_states = hidden_states.contiguous() + # compute next hidden_states + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) - # if `output_size` is passed we force the interpolation output - # size and do not make use of `scale_factor=2` - if output_size is None: - hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest") - else: - hidden_states = torch.nn.functional.interpolate(hidden_states, size=output_size, mode="nearest") + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) - if _self.use_conv: - if _self.name == "conv": - hidden_states = _self.conv(hidden_states) - else: - hidden_states = _self.Conv2d_0(hidden_states) + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor return hidden_states - # replace upsamplers - for upsampler in upsamplers: - # make new scope - def make_replacer(upsampler): - def forward(hidden_states, output_size=None): - return forward_upsample(upsampler, hidden_states, output_size) - - return forward - - upsampler.forward = make_replacer(upsampler) -""" + def forward_xformers_0_14(self, hidden_states, **kwargs): + if not hasattr(self, "to_q"): + self.to_q = self.query + self.to_k = self.key + self.to_v = self.value + self.to_out = [self.proj_attn, torch.nn.Identity()] + self.heads = self.num_heads + return forward_xformers(self, hidden_states, **kwargs) + + if diffusers.__version__ < "0.15.0": + diffusers.models.attention.AttentionBlock.forward = forward_xformers_0_14 + else: + diffusers.models.attention_processor.Attention.forward = forward_xformers -def replace_vae_attn_to_memory_efficient(): - print("AttentionBlock.forward has been replaced to FlashAttention (not xformers)") - flash_func = FlashAttentionFunction - - def forward_flash_attn(self, hidden_states): - print("forward_flash_attn") - q_bucket_size = 512 - k_bucket_size = 1024 +def replace_vae_attn_to_sdpa(): + print("VAE: Attention.forward has been replaced to sdpa") + def forward_sdpa(self, hidden_states, **kwargs): residual = hidden_states batch, channel, height, width = hidden_states.shape @@ -478,27 +294,45 @@ def forward_flash_attn(self, hidden_states): hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) # proj to q, k, v - query_proj = self.query(hidden_states) - key_proj = self.key(hidden_states) - value_proj = self.value(hidden_states) + query_proj = self.to_q(hidden_states) + key_proj = self.to_k(hidden_states) + value_proj = self.to_v(hidden_states) query_proj, key_proj, value_proj = map( - lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (query_proj, key_proj, 
value_proj) + lambda t: rearrange(t, "b n (h d) -> b n h d", h=self.heads), (query_proj, key_proj, value_proj) ) - out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size) + out = torch.nn.functional.scaled_dot_product_attention( + query_proj, key_proj, value_proj, attn_mask=None, dropout_p=0.0, is_causal=False + ) - out = rearrange(out, "b h n d -> b n (h d)") + out = rearrange(out, "b n h d -> b n (h d)") # compute next hidden_states - hidden_states = self.proj_attn(hidden_states) + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) # res connect and rescale hidden_states = (hidden_states + residual) / self.rescale_output_factor return hidden_states - diffusers.models.attention.AttentionBlock.forward = forward_flash_attn + def forward_sdpa_0_14(self, hidden_states, **kwargs): + if not hasattr(self, "to_q"): + self.to_q = self.query + self.to_k = self.key + self.to_v = self.value + self.to_out = [self.proj_attn, torch.nn.Identity()] + self.heads = self.num_heads + return forward_sdpa(self, hidden_states, **kwargs) + + if diffusers.__version__ < "0.15.0": + diffusers.models.attention.AttentionBlock.forward = forward_sdpa_0_14 + else: + diffusers.models.attention_processor.Attention.forward = forward_sdpa # endregion @@ -1110,6 +944,17 @@ def __call__( if self.control_nets: guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images) + if reginonal_network: + num_sub_and_neg_prompts = len(text_embeddings) // batch_size + # last subprompt and negative prompt + text_emb_last = [] + for j in range(batch_size): + text_emb_last.append(text_embeddings[(j + 1) * num_sub_and_neg_prompts - 2]) + text_emb_last.append(text_embeddings[(j + 1) * num_sub_and_neg_prompts - 1]) + text_emb_last = torch.stack(text_emb_last) + else: + text_emb_last = text_embeddings + for i, t in enumerate(tqdm(timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) @@ -1117,11 +962,6 @@ def __call__( # predict the noise residual if self.control_nets and self.control_net_enabled: - if reginonal_network: - num_sub_and_neg_prompts = len(text_embeddings) // batch_size - text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt - else: - text_emb_last = text_embeddings noise_pred = original_control_net.call_unet_and_control_net( i, num_latent_input, @@ -1131,6 +971,7 @@ def __call__( i / len(timesteps), latent_model_input, t, + text_embeddings, text_emb_last, ).sample else: @@ -2342,6 +2183,17 @@ def main(args): tokenizer = loading_pipe.tokenizer del loading_pipe + # Diffusers U-Net to original U-Net + original_unet = UNet2DConditionModel( + unet.config.sample_size, + unet.config.attention_head_dim, + unet.config.cross_attention_dim, + unet.config.use_linear_projection, + unet.config.upcast_attention, + ) + original_unet.load_state_dict(unet.state_dict()) + unet = original_unet + # VAEを読み込む if args.vae is not None: vae = model_util.load_vae(args.vae, dtype) @@ -2366,8 +2218,9 @@ def main(args): # xformers、Hypernetwork対応 if not args.diffusers_xformers: - replace_unet_modules(unet, not args.xformers, args.xformers) - replace_vae_modules(vae, not args.xformers, args.xformers) + mem_eff = not (args.xformers or args.sdpa) + 
replace_unet_modules(unet, mem_eff, args.xformers, args.sdpa) + replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa) # tokenizerを読み込む print("loading tokenizer") @@ -2907,6 +2760,10 @@ def resize_images(imgs, size): print(f"iteration {gen_iter+1}/{args.n_iter}") iter_seed = random.randint(0, 0x7FFFFFFF) + # shuffle prompt list + if args.shuffle_prompts: + random.shuffle(prompt_list) + # バッチ処理の関数 def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): batch_size = len(batch) @@ -3124,6 +2981,8 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): for i, (image, prompt, negative_prompts, seed, clip_prompt) in enumerate( zip(images, prompts, negative_prompts, seeds, clip_prompts) ): + if highres_fix: + seed -= 1 # record original seed metadata = PngInfo() metadata.add_text("prompt", prompt) metadata.add_text("seed", str(seed)) @@ -3293,7 +3152,7 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): print("predefined seeds are exhausted") seed = None elif args.iter_same_seed: - seeds = iter_seed + seed = iter_seed else: seed = None # 前のを消す @@ -3480,9 +3339,15 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="use same seed for all prompts in iteration if no seed specified / 乱数seedの指定がないとき繰り返し内はすべて同じseedを使う(プロンプト間の差異の比較用)", ) + parser.add_argument( + "--shuffle_prompts", + action="store_true", + help="shuffle prompts in iteration / 繰り返し内のプロンプトをシャッフルする", + ) parser.add_argument("--fp16", action="store_true", help="use fp16 / fp16を指定し省メモリ化する") parser.add_argument("--bf16", action="store_true", help="use bfloat16 / bfloat16を指定し省メモリ化する") parser.add_argument("--xformers", action="store_true", help="use xformers / xformersを使用し高速化する") + parser.add_argument("--sdpa", action="store_true", help="use sdpa in PyTorch 2 / sdpa") parser.add_argument( "--diffusers_xformers", action="store_true", diff --git a/library/attention_processors.py b/library/attention_processors.py new file mode 100644 index 000000000..310c2cb1c --- /dev/null +++ b/library/attention_processors.py @@ -0,0 +1,227 @@ +import math +from typing import Any +from einops import rearrange +import torch +from diffusers.models.attention_processor import Attention + + +# flash attention forwards and backwards + +# https://arxiv.org/abs/2205.14135 + +EPSILON = 1e-6 + + +class FlashAttentionFunction(torch.autograd.function.Function): + @staticmethod + @torch.no_grad() + def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): + """Algorithm 2 in the paper""" + + device = q.device + dtype = q.dtype + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + + o = torch.zeros_like(q) + all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) + all_row_maxes = torch.full( + (*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device + ) + + scale = q.shape[-1] ** -0.5 + + if mask is None: + mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) + else: + mask = rearrange(mask, "b n -> b 1 1 n") + mask = mask.split(q_bucket_size, dim=-1) + + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + mask, + all_row_sums.split(q_bucket_size, dim=-2), + all_row_maxes.split(q_bucket_size, dim=-2), + ) + + for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff + + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + ) + + for k_ind, (kc, vc) in 
enumerate(col_splits): + k_start_index = k_ind * k_bucket_size + + attn_weights = ( + torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale + ) + + if row_mask is not None: + attn_weights.masked_fill_(~row_mask, max_neg_value) + + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones( + (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device + ).triu(q_start_index - k_start_index + 1) + attn_weights.masked_fill_(causal_mask, max_neg_value) + + block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) + attn_weights -= block_row_maxes + exp_weights = torch.exp(attn_weights) + + if row_mask is not None: + exp_weights.masked_fill_(~row_mask, 0.0) + + block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp( + min=EPSILON + ) + + new_row_maxes = torch.maximum(block_row_maxes, row_maxes) + + exp_values = torch.einsum( + "... i j, ... j d -> ... i d", exp_weights, vc + ) + + exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) + exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) + + new_row_sums = ( + exp_row_max_diff * row_sums + + exp_block_row_max_diff * block_row_sums + ) + + oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_( + (exp_block_row_max_diff / new_row_sums) * exp_values + ) + + row_maxes.copy_(new_row_maxes) + row_sums.copy_(new_row_sums) + + ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) + ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) + + return o + + @staticmethod + @torch.no_grad() + def backward(ctx, do): + """Algorithm 4 in the paper""" + + causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args + q, k, v, o, l, m = ctx.saved_tensors + + device = q.device + + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + + dq = torch.zeros_like(q) + dk = torch.zeros_like(k) + dv = torch.zeros_like(v) + + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + do.split(q_bucket_size, dim=-2), + mask, + l.split(q_bucket_size, dim=-2), + m.split(q_bucket_size, dim=-2), + dq.split(q_bucket_size, dim=-2), + ) + + for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff + + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + dk.split(k_bucket_size, dim=-2), + dv.split(k_bucket_size, dim=-2), + ) + + for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size + + attn_weights = ( + torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale + ) + + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones( + (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device + ).triu(q_start_index - k_start_index + 1) + attn_weights.masked_fill_(causal_mask, max_neg_value) + + exp_attn_weights = torch.exp(attn_weights - mc) + + if row_mask is not None: + exp_attn_weights.masked_fill_(~row_mask, 0.0) + + p = exp_attn_weights / lc + + dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc) + dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc) + + D = (doc * oc).sum(dim=-1, keepdims=True) + ds = p * scale * (dp - D) + + dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc) + dk_chunk = torch.einsum("... i j, ... i d -> ... 
j d", ds, qc) + + dqc.add_(dq_chunk) + dkc.add_(dk_chunk) + dvc.add_(dv_chunk) + + return dq, dk, dv, None, None, None, None + + +class FlashAttnProcessor: + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + ) -> Any: + q_bucket_size = 512 + k_bucket_size = 1024 + + h = attn.heads + q = attn.to_q(hidden_states) + + encoder_hidden_states = ( + encoder_hidden_states + if encoder_hidden_states is not None + else hidden_states + ) + encoder_hidden_states = encoder_hidden_states.to(hidden_states.dtype) + + if hasattr(attn, "hypernetwork") and attn.hypernetwork is not None: + context_k, context_v = attn.hypernetwork.forward( + hidden_states, encoder_hidden_states + ) + context_k = context_k.to(hidden_states.dtype) + context_v = context_v.to(hidden_states.dtype) + else: + context_k = encoder_hidden_states + context_v = encoder_hidden_states + + k = attn.to_k(context_k) + v = attn.to_v(context_v) + del encoder_hidden_states, hidden_states + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) + + out = FlashAttentionFunction.apply( + q, k, v, attention_mask, False, q_bucket_size, k_bucket_size + ) + + out = rearrange(out, "b h n d -> b n (h d)") + + out = attn.to_out[0](out) + out = attn.to_out[1](out) + return out diff --git a/library/config_util.py b/library/config_util.py index 98b417516..e8e0fda7c 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -33,8 +33,10 @@ from .train_util import ( DreamBoothSubset, FineTuningSubset, + ControlNetSubset, DreamBoothDataset, FineTuningDataset, + ControlNetDataset, DatasetGroup, ) @@ -54,6 +56,8 @@ class BaseSubsetParams: flip_aug: bool = False face_crop_aug_range: Optional[Tuple[float, float]] = None random_crop: bool = False + caption_prefix: Optional[str] = None + caption_suffix: Optional[str] = None caption_dropout_rate: float = 0.0 caption_dropout_every_n_epochs: int = 0 caption_tag_dropout_rate: float = 0.0 @@ -70,9 +74,14 @@ class DreamBoothSubsetParams(BaseSubsetParams): class FineTuningSubsetParams(BaseSubsetParams): metadata_file: Optional[str] = None +@dataclass +class ControlNetSubsetParams(BaseSubsetParams): + conditioning_data_dir: str = None + caption_extension: str = ".caption" + @dataclass class BaseDatasetParams: - tokenizer: CLIPTokenizer = None + tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None max_token_length: int = None resolution: Optional[Tuple[int, int]] = None debug_dataset: bool = False @@ -96,6 +105,15 @@ class FineTuningDatasetParams(BaseDatasetParams): bucket_reso_steps: int = 64 bucket_no_upscale: bool = False +@dataclass +class ControlNetDatasetParams(BaseDatasetParams): + batch_size: int = 1 + enable_bucket: bool = False + min_bucket_reso: int = 256 + max_bucket_reso: int = 1024 + bucket_reso_steps: int = 64 + bucket_no_upscale: bool = False + @dataclass class SubsetBlueprint: params: Union[DreamBoothSubsetParams, FineTuningSubsetParams] @@ -103,6 +121,7 @@ class SubsetBlueprint: @dataclass class DatasetBlueprint: is_dreambooth: bool + is_controlnet: bool params: Union[DreamBoothDatasetParams, FineTuningDatasetParams] subsets: Sequence[SubsetBlueprint] @@ -142,6 +161,8 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence] "keep_tokens": int, "token_warmup_min": int, "token_warmup_step": Any(float,int), + "caption_prefix": str, + "caption_suffix": str, } # DO means DropOut DO_SUBSET_ASCENDABLE_SCHEMA = { @@ -163,6 +184,13 @@ def __validate_and_convert_scalar_or_twodim(klass, 
value: Union[float, Sequence] Required("metadata_file"): str, "image_dir": str, } + CN_SUBSET_ASCENDABLE_SCHEMA = { + "caption_extension": str, + } + CN_SUBSET_DISTINCT_SCHEMA = { + Required("image_dir"): str, + Required("conditioning_data_dir"): str, + } # datasets schema DATASET_ASCENDABLE_SCHEMA = { @@ -192,8 +220,8 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence] "dataset_repeats": "num_repeats", } - def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_dropout: bool) -> None: - assert support_dreambooth or support_finetuning, "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。" + def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None: + assert support_dreambooth or support_finetuning or support_controlnet, "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。" self.db_subset_schema = self.__merge_dict( self.SUBSET_ASCENDABLE_SCHEMA, @@ -208,6 +236,13 @@ def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_d self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, ) + self.cn_subset_schema = self.__merge_dict( + self.SUBSET_ASCENDABLE_SCHEMA, + self.CN_SUBSET_DISTINCT_SCHEMA, + self.CN_SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + ) + self.db_dataset_schema = self.__merge_dict( self.DATASET_ASCENDABLE_SCHEMA, self.SUBSET_ASCENDABLE_SCHEMA, @@ -223,13 +258,23 @@ def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_d {"subsets": [self.ft_subset_schema]}, ) + self.cn_dataset_schema = self.__merge_dict( + self.DATASET_ASCENDABLE_SCHEMA, + self.SUBSET_ASCENDABLE_SCHEMA, + self.CN_SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + {"subsets": [self.cn_subset_schema]}, + ) + if support_dreambooth and support_finetuning: def validate_flex_dataset(dataset_config: dict): subsets_config = dataset_config.get("subsets", []) + if support_controlnet and all(["conditioning_data_dir" in subset for subset in subsets_config]): + return Schema(self.cn_dataset_schema)(dataset_config) # check dataset meets FT style # NOTE: all FT subsets should have "metadata_file" - if all(["metadata_file" in subset for subset in subsets_config]): + elif all(["metadata_file" in subset for subset in subsets_config]): return Schema(self.ft_dataset_schema)(dataset_config) # check dataset meets DB style # NOTE: all DB subsets should have no "metadata_file" @@ -241,13 +286,16 @@ def validate_flex_dataset(dataset_config: dict): self.dataset_schema = validate_flex_dataset elif support_dreambooth: self.dataset_schema = self.db_dataset_schema - else: + elif support_finetuning: self.dataset_schema = self.ft_dataset_schema + elif support_controlnet: + self.dataset_schema = self.cn_dataset_schema self.general_schema = self.__merge_dict( self.DATASET_ASCENDABLE_SCHEMA, self.SUBSET_ASCENDABLE_SCHEMA, self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {}, + self.CN_SUBSET_ASCENDABLE_SCHEMA if support_controlnet else {}, self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, ) @@ -318,7 +366,11 @@ def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, ** # NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets 
subsets = dataset_config.get("subsets", []) is_dreambooth = all(["metadata_file" not in subset for subset in subsets]) - if is_dreambooth: + is_controlnet = all(["conditioning_data_dir" in subset for subset in subsets]) + if is_controlnet: + subset_params_klass = ControlNetSubsetParams + dataset_params_klass = ControlNetDatasetParams + elif is_dreambooth: subset_params_klass = DreamBoothSubsetParams dataset_params_klass = DreamBoothDatasetParams else: @@ -333,7 +385,7 @@ def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, ** params = self.generate_params_by_fallbacks(dataset_params_klass, [dataset_config, general_config, argparse_config, runtime_params]) - dataset_blueprints.append(DatasetBlueprint(is_dreambooth, params, subset_blueprints)) + dataset_blueprints.append(DatasetBlueprint(is_dreambooth, is_controlnet, params, subset_blueprints)) dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints) @@ -361,10 +413,13 @@ def search_value(key: str, fallbacks: Sequence[dict], default_value = None): def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint): - datasets: List[Union[DreamBoothDataset, FineTuningDataset]] = [] + datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] for dataset_blueprint in dataset_group_blueprint.datasets: - if dataset_blueprint.is_dreambooth: + if dataset_blueprint.is_controlnet: + subset_klass = ControlNetSubset + dataset_klass = ControlNetDataset + elif dataset_blueprint.is_dreambooth: subset_klass = DreamBoothSubset dataset_klass = DreamBoothDataset else: @@ -379,6 +434,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu info = "" for i, dataset in enumerate(datasets): is_dreambooth = isinstance(dataset, DreamBoothDataset) + is_controlnet = isinstance(dataset, ControlNetDataset) info += dedent(f"""\ [Dataset {i}] batch_size: {dataset.batch_size} @@ -407,6 +463,8 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu caption_dropout_rate: {subset.caption_dropout_rate} caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} + caption_prefix: {subset.caption_prefix} + caption_suffix: {subset.caption_suffix} color_aug: {subset.color_aug} flip_aug: {subset.flip_aug} face_crop_aug_range: {subset.face_crop_aug_range} @@ -421,7 +479,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu class_tokens: {subset.class_tokens} caption_extension: {subset.caption_extension} \n"""), " ") - else: + elif not is_controlnet: info += indent(dedent(f"""\ metadata_file: {subset.metadata_file} \n"""), " ") @@ -479,6 +537,27 @@ def generate(base_dir: Optional[str], is_reg: bool): return subsets_config +def generate_controlnet_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, conditioning_data_dir: Optional[str] = None, caption_extension: str = ".txt"): + def generate(base_dir: Optional[str]): + if base_dir is None: + return [] + + base_dir: Path = Path(base_dir) + if not base_dir.is_dir(): + return [] + + subsets_config = [] + subset_config = {"image_dir": train_data_dir, "conditioning_data_dir": conditioning_data_dir, "caption_extension": caption_extension, "num_repeats": 1} + subsets_config.append(subset_config) + + return subsets_config + + subsets_config = [] + subsets_config += generate(train_data_dir) + + return subsets_config + + def load_user_config(file: str) -> dict: file: Path = Path(file) if not 
file.is_file(): @@ -507,6 +586,7 @@ def load_user_config(file: str) -> dict: parser = argparse.ArgumentParser() parser.add_argument("--support_dreambooth", action="store_true") parser.add_argument("--support_finetuning", action="store_true") + parser.add_argument("--support_controlnet", action="store_true") parser.add_argument("--support_dropout", action="store_true") parser.add_argument("dataset_config") config_args, remain = parser.parse_known_args() @@ -525,7 +605,7 @@ def load_user_config(file: str) -> dict: print("\n[user_config]") print(user_config) - sanitizer = ConfigSanitizer(config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout) + sanitizer = ConfigSanitizer(config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout) sanitized_user_config = sanitizer.sanitize_user_config(user_config) print("\n[sanitized_user_config]") diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 8b44874b9..677d1bf46 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -19,20 +19,71 @@ def prepare_scheduler_for_custom_training(noise_scheduler, device): noise_scheduler.all_snr = all_snr.to(device) +def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler): + # fix beta: zero terminal SNR + print(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891") + + def enforce_zero_terminal_snr(betas): + # Convert betas to alphas_bar_sqrt + alphas = 1 - betas + alphas_bar = alphas.cumprod(0) + alphas_bar_sqrt = alphas_bar.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + # Shift so last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + # Scale so first timestep is back to old value. 
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 + alphas = alphas_bar[1:] / alphas_bar[:-1] + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + return betas + + betas = noise_scheduler.betas + betas = enforce_zero_terminal_snr(betas) + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + + # print("original:", noise_scheduler.betas) + # print("fixed:", betas) + + noise_scheduler.betas = betas + noise_scheduler.alphas = alphas + noise_scheduler.alphas_cumprod = alphas_cumprod + + def apply_snr_weight(loss, timesteps, noise_scheduler, gamma): snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr) - snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float() # from paper + snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device) # from paper loss = loss * snr_weight return loss def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler): + scale = get_snr_scale(timesteps, noise_scheduler) + loss = loss * scale + return loss + + +def get_snr_scale(timesteps, noise_scheduler): snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000 scale = snr_t / (snr_t + 1) + # # show debug info + # print(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}") + return scale - loss = loss * scale + +def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss): + scale = get_snr_scale(timesteps, noise_scheduler) + # print(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}") + loss = loss + loss / scale * v_pred_like_loss return loss @@ -51,6 +102,12 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted action="store_true", help="scale v-prediction loss like noise prediction loss / v-prediction lossをnoise prediction lossと同じようにスケーリングする", ) + parser.add_argument( + "--v_pred_like_loss", + type=float, + default=None, + help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけたものをlossに加算する", + ) if support_weighted_captions: parser.add_argument( "--weighted_captions", diff --git a/library/huggingface_util.py b/library/huggingface_util.py index 1dc496ff5..376fdb1e6 100644 --- a/library/huggingface_util.py +++ b/library/huggingface_util.py @@ -26,7 +26,7 @@ def upload( repo_id = args.huggingface_repo_id repo_type = args.huggingface_repo_type token = args.huggingface_token - path_in_repo = args.huggingface_path_in_repo + dest_suffix + path_in_repo = args.huggingface_path_in_repo + dest_suffix if args.huggingface_path_in_repo is not None else None private = args.huggingface_repo_visibility is None or args.huggingface_repo_visibility != "public" api = HfApi(token=token) if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token): diff --git a/library/hypernetwork.py b/library/hypernetwork.py new file mode 100644 index 000000000..fbd3fb24e --- /dev/null +++ b/library/hypernetwork.py @@ -0,0 +1,223 @@ +import torch +import torch.nn.functional as F +from diffusers.models.attention_processor import ( + Attention, + AttnProcessor2_0, + SlicedAttnProcessor, + XFormersAttnProcessor +) + +try: + import xformers.ops +except: + xformers = None + + 
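(Editor's note) Stepping back to the `library/custom_train_functions.py` hunk above: `apply_snr_weight` clamps `gamma / SNR(t)` at 1 (min-SNR-gamma), `get_snr_scale` computes `SNR(t) / (SNR(t) + 1)` with the SNR capped at 1000, and the new `--v_pred_like_loss` option adds `loss / scale * v_pred_like_loss` on top of the base loss. A minimal, standalone restatement of that arithmetic with toy SNR values (not the repository code) is:

```python
# Standalone sketch of the SNR-based loss weighting added above (toy values only).
import torch

all_snr = torch.tensor([100.0, 10.0, 1.0, 0.1])  # per-timestep SNR, precomputed as in the hunk
timesteps = torch.tensor([0, 1, 2, 3])
loss = torch.ones(4)                             # per-sample loss before weighting

# apply_snr_weight: weight = min(gamma / SNR, 1)
gamma = 5.0
snr = all_snr[timesteps]
snr_weight = torch.minimum(gamma / snr, torch.ones_like(snr))
print(loss * snr_weight)                         # tensor([0.0500, 0.5000, 1.0000, 1.0000])

# get_snr_scale: scale = SNR / (SNR + 1), with SNR capped at 1000
snr_capped = torch.minimum(snr, torch.full_like(snr, 1000.0))
scale = snr_capped / (snr_capped + 1)

# add_v_prediction_like_loss: loss + loss / scale * v_pred_like_loss
v_pred_like_loss = 0.1
print(loss + loss / scale * v_pred_like_loss)
```

The zero-terminal-SNR fix in the same hunk rescales the betas so that the final `alphas_cumprod` is exactly zero, following the paper linked in the comment (arXiv:2305.08891).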
+loaded_networks = [] + + +def apply_single_hypernetwork( + hypernetwork, hidden_states, encoder_hidden_states +): + context_k, context_v = hypernetwork.forward(hidden_states, encoder_hidden_states) + return context_k, context_v + + +def apply_hypernetworks(context_k, context_v, layer=None): + if len(loaded_networks) == 0: + return context_v, context_v + for hypernetwork in loaded_networks: + context_k, context_v = hypernetwork.forward(context_k, context_v) + + context_k = context_k.to(dtype=context_k.dtype) + context_v = context_v.to(dtype=context_k.dtype) + + return context_k, context_v + + + +def xformers_forward( + self: XFormersAttnProcessor, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + attention_mask: torch.Tensor = None, +): + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states) + + key = attn.to_k(context_k) + value = attn.to_v(context_v) + + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + hidden_states = xformers.ops.memory_efficient_attention( + query, + key, + value, + attn_bias=attention_mask, + op=self.attention_op, + scale=attn.scale, + ) + hidden_states = hidden_states.to(query.dtype) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +def sliced_attn_forward( + self: SlicedAttnProcessor, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + attention_mask: torch.Tensor = None, +): + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + + query = attn.to_q(hidden_states) + dim = query.shape[-1] + query = attn.head_to_batch_dim(query) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states) + + key = attn.to_k(context_k) + value = attn.to_v(context_v) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + batch_size_attention, query_tokens, _ = query.shape + hidden_states = torch.zeros( + (batch_size_attention, query_tokens, dim // attn.heads), + device=query.device, + dtype=query.dtype, + ) + + for i in range(batch_size_attention // self.slice_size): + start_idx = i * self.slice_size + end_idx = (i + 1) * self.slice_size + + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + attn_mask_slice = ( + attention_mask[start_idx:end_idx] if attention_mask is not None else None + ) + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + + attn_slice = torch.bmm(attn_slice, 
value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +def v2_0_forward( + self: AttnProcessor2_0, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, +): + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + inner_dim = hidden_states.shape[-1] + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view( + batch_size, attn.heads, -1, attention_mask.shape[-1] + ) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states) + + key = attn.to_k(context_k) + value = attn.to_v(context_v) + + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +def replace_attentions_for_hypernetwork(): + import diffusers.models.attention_processor + + diffusers.models.attention_processor.XFormersAttnProcessor.__call__ = ( + xformers_forward + ) + diffusers.models.attention_processor.SlicedAttnProcessor.__call__ = ( + sliced_attn_forward + ) + diffusers.models.attention_processor.AttnProcessor2_0.__call__ = v2_0_forward diff --git a/library/ipex/__init__.py b/library/ipex/__init__.py new file mode 100644 index 000000000..43accd9f3 --- /dev/null +++ b/library/ipex/__init__.py @@ -0,0 +1,175 @@ +import os +import sys +import contextlib +import torch +import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import +from .hijacks import ipex_hijacks +from .attention import attention_init + +# pylint: disable=protected-access, missing-function-docstring, line-too-long + +def ipex_init(): # pylint: disable=too-many-statements + try: + #Replace cuda with xpu: + torch.cuda.current_device = torch.xpu.current_device + torch.cuda.current_stream = torch.xpu.current_stream + torch.cuda.device = torch.xpu.device + torch.cuda.device_count = torch.xpu.device_count + torch.cuda.device_of = torch.xpu.device_of + torch.cuda.get_device_name = torch.xpu.get_device_name + torch.cuda.get_device_properties = torch.xpu.get_device_properties + torch.cuda.init = torch.xpu.init + torch.cuda.is_available = torch.xpu.is_available + 
torch.cuda.is_initialized = torch.xpu.is_initialized + torch.cuda.is_current_stream_capturing = lambda: False + torch.cuda.set_device = torch.xpu.set_device + torch.cuda.stream = torch.xpu.stream + torch.cuda.synchronize = torch.xpu.synchronize + torch.cuda.Event = torch.xpu.Event + torch.cuda.Stream = torch.xpu.Stream + torch.cuda.FloatTensor = torch.xpu.FloatTensor + torch.Tensor.cuda = torch.Tensor.xpu + torch.Tensor.is_cuda = torch.Tensor.is_xpu + torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock + torch.cuda._initialized = torch.xpu.lazy_init._initialized + torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker + torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls + torch.cuda._tls = torch.xpu.lazy_init._tls + torch.cuda.threading = torch.xpu.lazy_init.threading + torch.cuda.traceback = torch.xpu.lazy_init.traceback + torch.cuda.Optional = torch.xpu.Optional + torch.cuda.__cached__ = torch.xpu.__cached__ + torch.cuda.__loader__ = torch.xpu.__loader__ + torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage + torch.cuda.Tuple = torch.xpu.Tuple + torch.cuda.streams = torch.xpu.streams + torch.cuda._lazy_new = torch.xpu._lazy_new + torch.cuda.FloatStorage = torch.xpu.FloatStorage + torch.cuda.Any = torch.xpu.Any + torch.cuda.__doc__ = torch.xpu.__doc__ + torch.cuda.default_generators = torch.xpu.default_generators + torch.cuda.HalfTensor = torch.xpu.HalfTensor + torch.cuda._get_device_index = torch.xpu._get_device_index + torch.cuda.__path__ = torch.xpu.__path__ + torch.cuda.Device = torch.xpu.Device + torch.cuda.IntTensor = torch.xpu.IntTensor + torch.cuda.ByteStorage = torch.xpu.ByteStorage + torch.cuda.set_stream = torch.xpu.set_stream + torch.cuda.BoolStorage = torch.xpu.BoolStorage + torch.cuda.os = torch.xpu.os + torch.cuda.torch = torch.xpu.torch + torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage + torch.cuda.Union = torch.xpu.Union + torch.cuda.DoubleTensor = torch.xpu.DoubleTensor + torch.cuda.ShortTensor = torch.xpu.ShortTensor + torch.cuda.LongTensor = torch.xpu.LongTensor + torch.cuda.IntStorage = torch.xpu.IntStorage + torch.cuda.LongStorage = torch.xpu.LongStorage + torch.cuda.__annotations__ = torch.xpu.__annotations__ + torch.cuda.__package__ = torch.xpu.__package__ + torch.cuda.__builtins__ = torch.xpu.__builtins__ + torch.cuda.CharTensor = torch.xpu.CharTensor + torch.cuda.List = torch.xpu.List + torch.cuda._lazy_init = torch.xpu._lazy_init + torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor + torch.cuda.DoubleStorage = torch.xpu.DoubleStorage + torch.cuda.ByteTensor = torch.xpu.ByteTensor + torch.cuda.StreamContext = torch.xpu.StreamContext + torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage + torch.cuda.ShortStorage = torch.xpu.ShortStorage + torch.cuda._lazy_call = torch.xpu._lazy_call + torch.cuda.HalfStorage = torch.xpu.HalfStorage + torch.cuda.random = torch.xpu.random + torch.cuda._device = torch.xpu._device + torch.cuda.classproperty = torch.xpu.classproperty + torch.cuda.__name__ = torch.xpu.__name__ + torch.cuda._device_t = torch.xpu._device_t + torch.cuda.warnings = torch.xpu.warnings + torch.cuda.__spec__ = torch.xpu.__spec__ + torch.cuda.BoolTensor = torch.xpu.BoolTensor + torch.cuda.CharStorage = torch.xpu.CharStorage + torch.cuda.__file__ = torch.xpu.__file__ + torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork + #torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing + + #Memory: + torch.cuda.memory = torch.xpu.memory + if 'linux' in 
sys.platform and "WSL2" in os.popen("uname -a").read(): + torch.xpu.empty_cache = lambda: None + torch.cuda.empty_cache = torch.xpu.empty_cache + torch.cuda.memory_stats = torch.xpu.memory_stats + torch.cuda.memory_summary = torch.xpu.memory_summary + torch.cuda.memory_snapshot = torch.xpu.memory_snapshot + torch.cuda.memory_allocated = torch.xpu.memory_allocated + torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated + torch.cuda.memory_reserved = torch.xpu.memory_reserved + torch.cuda.memory_cached = torch.xpu.memory_reserved + torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved + torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved + torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats + torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats + torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats + torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict + torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats + + #RNG: + torch.cuda.get_rng_state = torch.xpu.get_rng_state + torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all + torch.cuda.set_rng_state = torch.xpu.set_rng_state + torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all + torch.cuda.manual_seed = torch.xpu.manual_seed + torch.cuda.manual_seed_all = torch.xpu.manual_seed_all + torch.cuda.seed = torch.xpu.seed + torch.cuda.seed_all = torch.xpu.seed_all + torch.cuda.initial_seed = torch.xpu.initial_seed + + #AMP: + torch.cuda.amp = torch.xpu.amp + if not hasattr(torch.cuda.amp, "common"): + torch.cuda.amp.common = contextlib.nullcontext() + torch.cuda.amp.common.amp_definitely_not_available = lambda: False + try: + torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler + except Exception: # pylint: disable=broad-exception-caught + try: + from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error + gradscaler_init() + torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler + except Exception: # pylint: disable=broad-exception-caught + torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler + + #C + torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream + ipex._C._DeviceProperties.major = 2023 + ipex._C._DeviceProperties.minor = 2 + + #Fix functions with ipex: + torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory] + torch._utils._get_available_device_type = lambda: "xpu" + torch.has_cuda = True + torch.cuda.has_half = True + torch.cuda.is_bf16_supported = lambda *args, **kwargs: True + torch.cuda.is_fp16_supported = lambda *args, **kwargs: True + torch.version.cuda = "11.7" + torch.cuda.get_device_capability = lambda *args, **kwargs: [11,7] + torch.cuda.get_device_properties.major = 11 + torch.cuda.get_device_properties.minor = 7 + torch.cuda.ipc_collect = lambda *args, **kwargs: None + torch.cuda.utilization = lambda *args, **kwargs: 0 + if hasattr(torch.xpu, 'getDeviceIdListForCard'): + torch.cuda.getDeviceIdListForCard = torch.xpu.getDeviceIdListForCard + torch.cuda.get_device_id_list_per_card = torch.xpu.getDeviceIdListForCard + else: + torch.cuda.getDeviceIdListForCard = torch.xpu.get_device_id_list_per_card + torch.cuda.get_device_id_list_per_card = torch.xpu.get_device_id_list_per_card + + ipex_hijacks() + attention_init() + try: + from .diffusers import ipex_diffusers + ipex_diffusers() 
+ except Exception: # pylint: disable=broad-exception-caught + pass + except Exception as e: + return False, e + return True, None diff --git a/library/ipex/attention.py b/library/ipex/attention.py new file mode 100644 index 000000000..84848b6a6 --- /dev/null +++ b/library/ipex/attention.py @@ -0,0 +1,157 @@ +import torch +import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import + +# pylint: disable=protected-access, missing-function-docstring, line-too-long + +original_torch_bmm = torch.bmm +def torch_bmm(input, mat2, *, out=None): + if input.dtype != mat2.dtype: + mat2 = mat2.to(input.dtype) + + #ARC GPUs can't allocate more than 4GB to a single block, Slice it: + batch_size_attention, input_tokens, mat2_shape = input.shape[0], input.shape[1], mat2.shape[2] + block_multiply = input.element_size() + slice_block_size = input_tokens * mat2_shape / 1024 / 1024 * block_multiply + block_size = batch_size_attention * slice_block_size + + split_slice_size = batch_size_attention + if block_size > 4: + do_split = True + #Find something divisible with the input_tokens + while (split_slice_size * slice_block_size) > 4: + split_slice_size = split_slice_size // 2 + if split_slice_size <= 1: + split_slice_size = 1 + break + else: + do_split = False + + split_2_slice_size = input_tokens + if split_slice_size * slice_block_size > 4: + slice_block_size2 = split_slice_size * mat2_shape / 1024 / 1024 * block_multiply + do_split_2 = True + #Find something divisible with the input_tokens + while (split_2_slice_size * slice_block_size2) > 4: + split_2_slice_size = split_2_slice_size // 2 + if split_2_slice_size <= 1: + split_2_slice_size = 1 + break + else: + do_split_2 = False + + if do_split: + hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype) + for i in range(batch_size_attention // split_slice_size): + start_idx = i * split_slice_size + end_idx = (i + 1) * split_slice_size + if do_split_2: + for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name + start_idx_2 = i2 * split_2_slice_size + end_idx_2 = (i2 + 1) * split_2_slice_size + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm( + input[start_idx:end_idx, start_idx_2:end_idx_2], + mat2[start_idx:end_idx, start_idx_2:end_idx_2], + out=out + ) + else: + hidden_states[start_idx:end_idx] = original_torch_bmm( + input[start_idx:end_idx], + mat2[start_idx:end_idx], + out=out + ) + else: + return original_torch_bmm(input, mat2, out=out) + return hidden_states + +original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention +def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False): + #ARC GPUs can't allocate more than 4GB to a single block, Slice it: + if len(query.shape) == 3: + batch_size_attention, query_tokens, shape_four = query.shape + shape_one = 1 + no_shape_one = True + else: + shape_one, batch_size_attention, query_tokens, shape_four = query.shape + no_shape_one = False + + block_multiply = query.element_size() + slice_block_size = shape_one * query_tokens * shape_four / 1024 / 1024 * block_multiply + block_size = batch_size_attention * slice_block_size + + split_slice_size = batch_size_attention + if block_size > 4: + do_split = True + #Find something divisible with the shape_one + while (split_slice_size * slice_block_size) > 4: + split_slice_size = split_slice_size // 2 + if split_slice_size <= 1: + split_slice_size = 1 + break + 
else: + do_split = False + + split_2_slice_size = query_tokens + if split_slice_size * slice_block_size > 4: + slice_block_size2 = shape_one * split_slice_size * shape_four / 1024 / 1024 * block_multiply + do_split_2 = True + #Find something divisible with the batch_size_attention + while (split_2_slice_size * slice_block_size2) > 4: + split_2_slice_size = split_2_slice_size // 2 + if split_2_slice_size <= 1: + split_2_slice_size = 1 + break + else: + do_split_2 = False + + if do_split: + hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype) + for i in range(batch_size_attention // split_slice_size): + start_idx = i * split_slice_size + end_idx = (i + 1) * split_slice_size + if do_split_2: + for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name + start_idx_2 = i2 * split_2_slice_size + end_idx_2 = (i2 + 1) * split_2_slice_size + if no_shape_one: + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention( + query[start_idx:end_idx, start_idx_2:end_idx_2], + key[start_idx:end_idx, start_idx_2:end_idx_2], + value[start_idx:end_idx, start_idx_2:end_idx_2], + attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask, + dropout_p=dropout_p, is_causal=is_causal + ) + else: + hidden_states[:, start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention( + query[:, start_idx:end_idx, start_idx_2:end_idx_2], + key[:, start_idx:end_idx, start_idx_2:end_idx_2], + value[:, start_idx:end_idx, start_idx_2:end_idx_2], + attn_mask=attn_mask[:, start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask, + dropout_p=dropout_p, is_causal=is_causal + ) + else: + if no_shape_one: + hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention( + query[start_idx:end_idx], + key[start_idx:end_idx], + value[start_idx:end_idx], + attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask, + dropout_p=dropout_p, is_causal=is_causal + ) + else: + hidden_states[:, start_idx:end_idx] = original_scaled_dot_product_attention( + query[:, start_idx:end_idx], + key[:, start_idx:end_idx], + value[:, start_idx:end_idx], + attn_mask=attn_mask[:, start_idx:end_idx] if attn_mask is not None else attn_mask, + dropout_p=dropout_p, is_causal=is_causal + ) + else: + return original_scaled_dot_product_attention( + query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal + ) + return hidden_states + +def attention_init(): + #ARC GPUs can't allocate more than 4GB to a single block: + torch.bmm = torch_bmm + torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention diff --git a/library/ipex/diffusers.py b/library/ipex/diffusers.py new file mode 100644 index 000000000..005ee49f0 --- /dev/null +++ b/library/ipex/diffusers.py @@ -0,0 +1,120 @@ +import torch +import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import +import diffusers #0.21.1 # pylint: disable=import-error +from diffusers.models.attention_processor import Attention + +# pylint: disable=protected-access, missing-function-docstring, line-too-long + +class SlicedAttnProcessor: # pylint: disable=too-few-public-methods + r""" + Processor for implementing sliced attention. + + Args: + slice_size (`int`, *optional*): + The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and + `attention_head_dim` must be a multiple of the `slice_size`. 
+ """ + + def __init__(self, slice_size): + self.slice_size = slice_size + + def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): # pylint: disable=too-many-statements, too-many-locals, too-many-branches + residual = hidden_states + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + dim = query.shape[-1] + query = attn.head_to_batch_dim(query) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + batch_size_attention, query_tokens, shape_three = query.shape + hidden_states = torch.zeros( + (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype + ) + + #ARC GPUs can't allocate more than 4GB to a single block, Slice it: + block_multiply = query.element_size() + slice_block_size = self.slice_size * shape_three / 1024 / 1024 * block_multiply + block_size = query_tokens * slice_block_size + split_2_slice_size = query_tokens + if block_size > 4: + do_split_2 = True + #Find something divisible with the query_tokens + while (split_2_slice_size * slice_block_size) > 4: + split_2_slice_size = split_2_slice_size // 2 + if split_2_slice_size <= 1: + split_2_slice_size = 1 + break + else: + do_split_2 = False + + for i in range(batch_size_attention // self.slice_size): + start_idx = i * self.slice_size + end_idx = (i + 1) * self.slice_size + + if do_split_2: + for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name + start_idx_2 = i2 * split_2_slice_size + end_idx_2 = (i2 + 1) * split_2_slice_size + + query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2] + key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2] + attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2]) + + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice + else: + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: 
+ hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + +def ipex_diffusers(): + #ARC GPUs can't allocate more than 4GB to a single block: + diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor diff --git a/library/ipex/gradscaler.py b/library/ipex/gradscaler.py new file mode 100644 index 000000000..530212101 --- /dev/null +++ b/library/ipex/gradscaler.py @@ -0,0 +1,179 @@ +from collections import defaultdict +import torch +import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import +import intel_extension_for_pytorch._C as core # pylint: disable=import-error, unused-import + +# pylint: disable=protected-access, missing-function-docstring, line-too-long + +OptState = ipex.cpu.autocast._grad_scaler.OptState +_MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator +_refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state + +def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): # pylint: disable=unused-argument + per_device_inv_scale = _MultiDeviceReplicator(inv_scale) + per_device_found_inf = _MultiDeviceReplicator(found_inf) + + # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype. + # There could be hundreds of grads, so we'd like to iterate through them just once. + # However, we don't know their devices or dtypes in advance. + + # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict + # Google says mypy struggles with defaultdicts type annotations. + per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated] + # sync grad to master weight + if hasattr(optimizer, "sync_grad"): + optimizer.sync_grad() + with torch.no_grad(): + for group in optimizer.param_groups: + for param in group["params"]: + if param.grad is None: + continue + if (not allow_fp16) and param.grad.dtype == torch.float16: + raise ValueError("Attempting to unscale FP16 gradients.") + if param.grad.is_sparse: + # is_coalesced() == False means the sparse grad has values with duplicate indices. + # coalesce() deduplicates indices and adds all values that have the same index. + # For scaled fp16 values, there's a good chance coalescing will cause overflow, + # so we should check the coalesced _values(). + if param.grad.dtype is torch.float16: + param.grad = param.grad.coalesce() + to_unscale = param.grad._values() + else: + to_unscale = param.grad + + # -: is there a way to split by device and dtype without appending in the inner loop? + to_unscale = to_unscale.to("cpu") + per_device_and_dtype_grads[to_unscale.device][ + to_unscale.dtype + ].append(to_unscale) + + for _, per_dtype_grads in per_device_and_dtype_grads.items(): + for grads in per_dtype_grads.values(): + core._amp_foreach_non_finite_check_and_unscale_( + grads, + per_device_found_inf.get("cpu"), + per_device_inv_scale.get("cpu"), + ) + + return per_device_found_inf._per_device_tensors + +def unscale_(self, optimizer): + """ + Divides ("unscales") the optimizer's gradient tensors by the scale factor. + :meth:`unscale_` is optional, serving cases where you need to + :ref:`modify or inspect gradients` + between the backward pass(es) and :meth:`step`. + If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`. + Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients:: + ... 
+ scaler.scale(loss).backward() + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) + scaler.step(optimizer) + scaler.update() + Args: + optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled. + .. warning:: + :meth:`unscale_` should only be called once per optimizer per :meth:`step` call, + and only after all gradients for that optimizer's assigned parameters have been accumulated. + Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError. + .. warning:: + :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute. + """ + if not self._enabled: + return + + self._check_scale_growth_tracker("unscale_") + + optimizer_state = self._per_optimizer_states[id(optimizer)] + + if optimizer_state["stage"] is OptState.UNSCALED: # pylint: disable=no-else-raise + raise RuntimeError( + "unscale_() has already been called on this optimizer since the last update()." + ) + elif optimizer_state["stage"] is OptState.STEPPED: + raise RuntimeError("unscale_() is being called after step().") + + # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64. + assert self._scale is not None + inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device) + found_inf = torch.full( + (1,), 0.0, dtype=torch.float32, device=self._scale.device + ) + + optimizer_state["found_inf_per_device"] = self._unscale_grads_( + optimizer, inv_scale, found_inf, False + ) + optimizer_state["stage"] = OptState.UNSCALED + +def update(self, new_scale=None): + """ + Updates the scale factor. + If any optimizer steps were skipped the scale is multiplied by ``backoff_factor`` + to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively, + the scale is multiplied by ``growth_factor`` to increase it. + Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not + used directly, it's used to fill GradScaler's internal scale tensor. So if + ``new_scale`` was a tensor, later in-place changes to that tensor will not further + affect the scale GradScaler uses internally.) + Args: + new_scale (float or :class:`torch.FloatTensor`, optional, default=None): New scale factor. + .. warning:: + :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has + been invoked for all optimizers used this iteration. + """ + if not self._enabled: + return + + _scale, _growth_tracker = self._check_scale_growth_tracker("update") + + if new_scale is not None: + # Accept a new user-defined scale. + if isinstance(new_scale, float): + self._scale.fill_(new_scale) # type: ignore[union-attr] + else: + reason = "new_scale should be a float or a 1-element torch.FloatTensor with requires_grad=False." + assert isinstance(new_scale, torch.FloatTensor), reason # type: ignore[attr-defined] + assert new_scale.numel() == 1, reason + assert new_scale.requires_grad is False, reason + self._scale.copy_(new_scale) # type: ignore[union-attr] + else: + # Consume shared inf/nan data collected from optimizers to update the scale. + # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. + found_infs = [ + found_inf.to(device="cpu", non_blocking=True) + for state in self._per_optimizer_states.values() + for found_inf in state["found_inf_per_device"].values() + ] + + assert len(found_infs) > 0, "No inf checks were recorded prior to update." 
+ + found_inf_combined = found_infs[0] + if len(found_infs) > 1: + for i in range(1, len(found_infs)): + found_inf_combined += found_infs[i] + + to_device = _scale.device + _scale = _scale.to("cpu") + _growth_tracker = _growth_tracker.to("cpu") + + core._amp_update_scale_( + _scale, + _growth_tracker, + found_inf_combined, + self._growth_factor, + self._backoff_factor, + self._growth_interval, + ) + + _scale = _scale.to(to_device) + _growth_tracker = _growth_tracker.to(to_device) + # To prepare for next iteration, clear the data collected from optimizers this iteration. + self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) + +def gradscaler_init(): + torch.xpu.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler + torch.xpu.amp.GradScaler._unscale_grads_ = _unscale_grads_ + torch.xpu.amp.GradScaler.unscale_ = unscale_ + torch.xpu.amp.GradScaler.update = update + return torch.xpu.amp.GradScaler diff --git a/library/ipex/hijacks.py b/library/ipex/hijacks.py new file mode 100644 index 000000000..77ed5419a --- /dev/null +++ b/library/ipex/hijacks.py @@ -0,0 +1,196 @@ +import contextlib +import importlib +import torch +import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import + +# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return + +class CondFunc: # pylint: disable=missing-class-docstring + def __new__(cls, orig_func, sub_func, cond_func): + self = super(CondFunc, cls).__new__(cls) + if isinstance(orig_func, str): + func_path = orig_func.split('.') + for i in range(len(func_path)-1, -1, -1): + try: + resolved_obj = importlib.import_module('.'.join(func_path[:i])) + break + except ImportError: + pass + for attr_name in func_path[i:-1]: + resolved_obj = getattr(resolved_obj, attr_name) + orig_func = getattr(resolved_obj, func_path[-1]) + setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs)) + self.__init__(orig_func, sub_func, cond_func) + return lambda *args, **kwargs: self(*args, **kwargs) + def __init__(self, orig_func, sub_func, cond_func): + self.__orig_func = orig_func + self.__sub_func = sub_func + self.__cond_func = cond_func + def __call__(self, *args, **kwargs): + if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs): + return self.__sub_func(self.__orig_func, *args, **kwargs) + else: + return self.__orig_func(*args, **kwargs) + +_utils = torch.utils.data._utils +def _shutdown_workers(self): + if torch.utils.data._utils is None or torch.utils.data._utils.python_exit_status is True or torch.utils.data._utils.python_exit_status is None: + return + if hasattr(self, "_shutdown") and not self._shutdown: + self._shutdown = True + try: + if hasattr(self, '_pin_memory_thread'): + self._pin_memory_thread_done_event.set() + self._worker_result_queue.put((None, None)) + self._pin_memory_thread.join() + self._worker_result_queue.cancel_join_thread() + self._worker_result_queue.close() + self._workers_done_event.set() + for worker_id in range(len(self._workers)): + if self._persistent_workers or self._workers_status[worker_id]: + self._mark_worker_as_unavailable(worker_id, shutdown=True) + for w in self._workers: # pylint: disable=invalid-name + w.join(timeout=torch.utils.data._utils.MP_STATUS_CHECK_INTERVAL) + for q in self._index_queues: # pylint: disable=invalid-name + q.cancel_join_thread() + q.close() + finally: + if self._worker_pids_set: + torch.utils.data._utils.signal_handling._remove_worker_pids(id(self)) + 
self._worker_pids_set = False + for w in self._workers: # pylint: disable=invalid-name + if w.is_alive(): + w.terminate() + +class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods + def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument + if isinstance(device_ids, list) and len(device_ids) > 1: + print("IPEX backend doesn't support DataParallel on multiple XPU devices") + return module.to("xpu") + +def return_null_context(*args, **kwargs): # pylint: disable=unused-argument + return contextlib.nullcontext() + +def check_device(device): + return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int)) + +def return_xpu(device): + return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu" + +def ipex_no_cuda(orig_func, *args, **kwargs): + torch.cuda.is_available = lambda: False + orig_func(*args, **kwargs) + torch.cuda.is_available = torch.xpu.is_available + +original_autocast = torch.autocast +def ipex_autocast(*args, **kwargs): + if len(args) > 0 and args[0] == "cuda": + return original_autocast("xpu", *args[1:], **kwargs) + else: + return original_autocast(*args, **kwargs) + +original_torch_cat = torch.cat +def torch_cat(tensor, *args, **kwargs): + if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype): + return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs) + else: + return original_torch_cat(tensor, *args, **kwargs) + +original_interpolate = torch.nn.functional.interpolate +def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments + if antialias or align_corners is not None: + return_device = tensor.device + return_dtype = tensor.dtype + return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode, + align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias).to(return_device, dtype=return_dtype) + else: + return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode, + align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias) + +original_linalg_solve = torch.linalg.solve +def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name + if A.device != torch.device("cpu") or B.device != torch.device("cpu"): + return_device = A.device + return original_linalg_solve(A.to("cpu"), B.to("cpu"), *args, **kwargs).to(return_device) + else: + return original_linalg_solve(A, B, *args, **kwargs) + +def ipex_hijacks(): + CondFunc('torch.Tensor.to', + lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs), + lambda orig_func, self, device=None, *args, **kwargs: check_device(device)) + CondFunc('torch.Tensor.cuda', + lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs), + lambda orig_func, self, device=None, *args, **kwargs: check_device(device)) + CondFunc('torch.empty', + lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), + 
lambda orig_func, *args, device=None, **kwargs: check_device(device)) + CondFunc('torch.load', + lambda orig_func, *args, map_location=None, **kwargs: orig_func(*args, return_xpu(map_location), **kwargs), + lambda orig_func, *args, map_location=None, **kwargs: map_location is None or check_device(map_location)) + CondFunc('torch.randn', + lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), + lambda orig_func, *args, device=None, **kwargs: check_device(device)) + CondFunc('torch.ones', + lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), + lambda orig_func, *args, device=None, **kwargs: check_device(device)) + CondFunc('torch.zeros', + lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), + lambda orig_func, *args, device=None, **kwargs: check_device(device)) + CondFunc('torch.tensor', + lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), + lambda orig_func, *args, device=None, **kwargs: check_device(device)) + CondFunc('torch.linspace', + lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), + lambda orig_func, *args, device=None, **kwargs: check_device(device)) + + CondFunc('torch.Generator', + lambda orig_func, device=None: torch.xpu.Generator(device), + lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu") + + CondFunc('torch.batch_norm', + lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input, + weight if weight is not None else torch.ones(input.size()[1], device=input.device), + bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs), + lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu")) + CondFunc('torch.instance_norm', + lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input, + weight if weight is not None else torch.ones(input.size()[1], device=input.device), + bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs), + lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu")) + + #Functions with dtype errors: + CondFunc('torch.nn.modules.GroupNorm.forward', + lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), + lambda orig_func, self, input: input.dtype != self.weight.data.dtype) + CondFunc('torch.nn.modules.linear.Linear.forward', + lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), + lambda orig_func, self, input: input.dtype != self.weight.data.dtype) + CondFunc('torch.nn.modules.conv.Conv2d.forward', + lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), + lambda orig_func, self, input: input.dtype != self.weight.data.dtype) + CondFunc('torch.nn.functional.layer_norm', + lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: + orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs), + lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: + weight is not None and input.dtype != weight.data.dtype) + + #Diffusers Float64 (ARC GPUs doesn't support double or Float64): + if not torch.xpu.has_fp64_dtype(): + CondFunc('torch.from_numpy', + lambda orig_func, ndarray: orig_func(ndarray.astype('float32')), + lambda orig_func, ndarray: ndarray.dtype == float) + + 
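(Editor's note) The `CondFunc` helper above is the core of `ipex_hijacks()`: given a dotted attribute path, it resolves the target, swaps it for a wrapper, and at call time runs the substitute only when the condition holds, otherwise falling through to the original. A simplified, self-contained restatement of the pattern follows; the `os.path.exists` target is a hypothetical example, not part of the patch:

```python
# Simplified sketch of the CondFunc monkey-patching pattern used above.
import importlib
import os.path

class CondFunc:
    def __new__(cls, orig_func, sub_func, cond_func):
        self = super().__new__(cls)
        if isinstance(orig_func, str):
            # Resolve the dotted path and install a wrapper in its place.
            path = orig_func.split(".")
            module = importlib.import_module(".".join(path[:-1]))
            orig_func = getattr(module, path[-1])
            setattr(module, path[-1], lambda *a, **kw: self(*a, **kw))
        self.__init__(orig_func, sub_func, cond_func)
        return lambda *a, **kw: self(*a, **kw)

    def __init__(self, orig_func, sub_func, cond_func):
        self._orig, self._sub, self._cond = orig_func, sub_func, cond_func

    def __call__(self, *args, **kwargs):
        # Run the substitute only when the condition holds; otherwise fall through.
        if not self._cond or self._cond(self._orig, *args, **kwargs):
            return self._sub(self._orig, *args, **kwargs)
        return self._orig(*args, **kwargs)

# Hypothetical use: override os.path.exists for one specific path only.
CondFunc("os.path.exists",
         lambda orig, p: True,                       # substitute
         lambda orig, p: p == "/definitely/there")   # condition
print(os.path.exists("/definitely/there"))  # True (substituted)
print(os.path.exists("/tmp"))               # falls through to the original
```

In the hunks above the condition is typically `check_device(device)`, so only calls that explicitly target a CUDA device are redirected to XPU via `return_xpu`; everything else runs unchanged.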
#Broken functions when torch.cuda.is_available is True: + CondFunc('torch.utils.data.dataloader._BaseDataLoaderIter.__init__', + lambda orig_func, *args, **kwargs: ipex_no_cuda(orig_func, *args, **kwargs), + lambda orig_func, *args, **kwargs: True) + + #Functions that make compile mad with CondFunc: + torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = _shutdown_workers + torch.nn.DataParallel = DummyDataParallel + torch.autocast = ipex_autocast + torch.cat = torch_cat + torch.linalg.solve = linalg_solve + torch.nn.functional.interpolate = interpolate + torch.backends.cuda.sdp_kernel = return_null_context diff --git a/library/lpw_stable_diffusion.py b/library/lpw_stable_diffusion.py index 58b1171e1..9dce91a76 100644 --- a/library/lpw_stable_diffusion.py +++ b/library/lpw_stable_diffusion.py @@ -6,7 +6,7 @@ from typing import Callable, List, Optional, Union import numpy as np -import PIL +import PIL.Image import torch from packaging import version from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer @@ -426,6 +426,58 @@ def preprocess_mask(mask, scale_factor=8): return mask +def prepare_controlnet_image( + image: PIL.Image.Image, + width: int, + height: int, + batch_size: int, + num_images_per_prompt: int, + device: torch.device, + dtype: torch.dtype, + do_classifier_free_guidance: bool = False, + guess_mode: bool = False, +): + if not isinstance(image, torch.Tensor): + if isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + images = [] + + for image_ in image: + image_ = image_.convert("RGB") + image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]) + image_ = np.array(image_) + image_ = image_[None, :] + images.append(image_) + + image = images + + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline): r""" Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing @@ -464,10 +516,11 @@ def __init__( tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: SchedulerMixin, - clip_skip: int, + # clip_skip: int, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, requires_safety_checker: bool = True, + clip_skip: int = 1, ): super().__init__( vae=vae, @@ -707,6 +760,8 @@ def __call__( max_embeddings_multiples: Optional[int] = 3, output_type: Optional[str] = "pil", return_dict: bool = True, + controlnet=None, + controlnet_image=None, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, is_cancelled_callback: Optional[Callable[[], bool]] = None, callback_steps: int = 1, @@ -767,6 +822,11 @@ def __call__( return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. 
+ controlnet (`diffusers.ControlNetModel`, *optional*): + A controlnet model to be used for the inference. If not provided, controlnet will be disabled. + controlnet_image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*): + `Image`, or tensor representing an image batch, to be used as the starting point for the controlnet + inference. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. @@ -785,6 +845,9 @@ def __call__( list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ + if controlnet is not None and controlnet_image is None: + raise ValueError("controlnet_image must be provided if controlnet is not None.") + # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor @@ -824,6 +887,11 @@ def __call__( else: mask = None + if controlnet_image is not None: + controlnet_image = prepare_controlnet_image( + controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False + ) + # 5. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None) @@ -851,8 +919,22 @@ def __call__( latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + unet_additional_args = {} + if controlnet is not None: + down_block_res_samples, mid_block_res_sample = controlnet( + latent_model_input, + t, + encoder_hidden_states=text_embeddings, + controlnet_cond=controlnet_image, + conditioning_scale=1.0, + guess_mode=False, + return_dict=False, + ) + unet_additional_args["down_block_additional_residuals"] = down_block_res_samples + unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample + # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, **unet_additional_args).sample # perform guidance if do_classifier_free_guidance: @@ -874,20 +956,13 @@ def __call__( if is_cancelled_callback is not None and is_cancelled_callback(): return None - # 9. Post-processing - image = self.decode_latents(latents) - - # 10. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype) - - # 11. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) + return latents - if not return_dict: - return image, has_nsfw_concept - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + def latents_to_image(self, latents): + # 9. 
Post-processing + image = self.decode_latents(latents.to(self.vae.dtype)) + image = self.numpy_to_pil(image) + return image def text2img( self, diff --git a/library/model_util.py b/library/model_util.py index 70a8c7523..00a3c0495 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -4,9 +4,18 @@ import math import os import torch +try: + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + from library.ipex import ipex_init + ipex_init() +except Exception: + pass +import diffusers from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging -from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel +from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel from safetensors.torch import load_file, save_file +from library.original_unet import UNet2DConditionModel # DiffUsers版StableDiffusionのモデルパラメータ NUM_TRAIN_TIMESTEPS = 1000 @@ -126,17 +135,30 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): new_item = new_item.replace("norm.weight", "group_norm.weight") new_item = new_item.replace("norm.bias", "group_norm.bias") - new_item = new_item.replace("q.weight", "query.weight") - new_item = new_item.replace("q.bias", "query.bias") + if diffusers.__version__ < "0.17.0": + new_item = new_item.replace("q.weight", "query.weight") + new_item = new_item.replace("q.bias", "query.bias") - new_item = new_item.replace("k.weight", "key.weight") - new_item = new_item.replace("k.bias", "key.bias") + new_item = new_item.replace("k.weight", "key.weight") + new_item = new_item.replace("k.bias", "key.bias") - new_item = new_item.replace("v.weight", "value.weight") - new_item = new_item.replace("v.bias", "value.bias") + new_item = new_item.replace("v.weight", "value.weight") + new_item = new_item.replace("v.bias", "value.bias") - new_item = new_item.replace("proj_out.weight", "proj_attn.weight") - new_item = new_item.replace("proj_out.bias", "proj_attn.bias") + new_item = new_item.replace("proj_out.weight", "proj_attn.weight") + new_item = new_item.replace("proj_out.bias", "proj_attn.bias") + else: + new_item = new_item.replace("q.weight", "to_q.weight") + new_item = new_item.replace("q.bias", "to_q.bias") + + new_item = new_item.replace("k.weight", "to_k.weight") + new_item = new_item.replace("k.bias", "to_k.bias") + + new_item = new_item.replace("v.weight", "to_v.weight") + new_item = new_item.replace("v.bias", "to_v.bias") + + new_item = new_item.replace("proj_out.weight", "to_out.0.weight") + new_item = new_item.replace("proj_out.bias", "to_out.0.bias") new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) @@ -191,8 +213,16 @@ def assign_to_checkpoint( new_path = new_path.replace(replacement["old"], replacement["new"]) # proj_attn.weight has to be converted from conv 1D to linear - if "proj_attn.weight" in new_path: - checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + reshaping = False + if diffusers.__version__ < "0.17.0": + if "proj_attn.weight" in new_path: + reshaping = True + else: + if ".attentions." 
in new_path and ".0.to_" in new_path and old_checkpoint[path["old"]].ndim > 2: + reshaping = True + + if reshaping: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0] else: checkpoint[new_path] = old_checkpoint[path["old"]] @@ -361,7 +391,7 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config): # SDのv2では1*1のconv2dがlinearに変わっている # 誤って Diffusers 側を conv2d のままにしてしまったので、変換必要 - if v2 and not config.get('use_linear_projection', False): + if v2 and not config.get("use_linear_projection", False): linear_transformer_to_conv(new_checkpoint) return new_checkpoint @@ -540,10 +570,10 @@ def convert_ldm_clip_checkpoint_v1(checkpoint): for key in keys: if key.startswith("cond_stage_model.transformer"): text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] - + # support checkpoint without position_ids (invalid checkpoint) if "text_model.embeddings.position_ids" not in text_model_dict: - text_model_dict["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0) # 77 is the max length of the text + text_model_dict["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0) # 77 is the max length of the text return text_model_dict @@ -737,6 +767,105 @@ def convert_unet_state_dict_to_sd(v2, unet_state_dict): return new_state_dict +def controlnet_conversion_map(): + unet_conversion_map = [ + ("time_embed.0.weight", "time_embedding.linear_1.weight"), + ("time_embed.0.bias", "time_embedding.linear_1.bias"), + ("time_embed.2.weight", "time_embedding.linear_2.weight"), + ("time_embed.2.bias", "time_embedding.linear_2.bias"), + ("input_blocks.0.0.weight", "conv_in.weight"), + ("input_blocks.0.0.bias", "conv_in.bias"), + ("middle_block_out.0.weight", "controlnet_mid_block.weight"), + ("middle_block_out.0.bias", "controlnet_mid_block.bias"), + ] + + unet_conversion_map_resnet = [ + ("in_layers.0", "norm1"), + ("in_layers.2", "conv1"), + ("out_layers.0", "norm2"), + ("out_layers.3", "conv2"), + ("emb_layers.1", "time_emb_proj"), + ("skip_connection", "conv_shortcut"), + ] + + unet_conversion_map_layer = [] + for i in range(4): + for j in range(2): + hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." + sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." + unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) + + if i < 3: + hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." + sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." + unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) + + if i < 3: + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." + sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." + unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) + + hf_mid_atn_prefix = "mid_block.attentions.0." + sd_mid_atn_prefix = "middle_block.1." + unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) + + for j in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{j}." + sd_mid_res_prefix = f"middle_block.{2*j}." + unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) + + controlnet_cond_embedding_names = ["conv_in"] + [f"blocks.{i}" for i in range(6)] + ["conv_out"] + for i, hf_prefix in enumerate(controlnet_cond_embedding_names): + hf_prefix = f"controlnet_cond_embedding.{hf_prefix}." + sd_prefix = f"input_hint_block.{i*2}." + unet_conversion_map_layer.append((sd_prefix, hf_prefix)) + + for i in range(12): + hf_prefix = f"controlnet_down_blocks.{i}." + sd_prefix = f"zero_convs.{i}.0." 
+ unet_conversion_map_layer.append((sd_prefix, hf_prefix)) + + return unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer + + +def convert_controlnet_state_dict_to_sd(controlnet_state_dict): + unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map() + + mapping = {k: k for k in controlnet_state_dict.keys()} + for sd_name, diffusers_name in unet_conversion_map: + mapping[diffusers_name] = sd_name + for k, v in mapping.items(): + if "resnets" in k: + for sd_part, diffusers_part in unet_conversion_map_resnet: + v = v.replace(diffusers_part, sd_part) + mapping[k] = v + for k, v in mapping.items(): + for sd_part, diffusers_part in unet_conversion_map_layer: + v = v.replace(diffusers_part, sd_part) + mapping[k] = v + new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()} + return new_state_dict + + +def convert_controlnet_state_dict_to_diffusers(controlnet_state_dict): + unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map() + + mapping = {k: k for k in controlnet_state_dict.keys()} + for sd_name, diffusers_name in unet_conversion_map: + mapping[sd_name] = diffusers_name + for k, v in mapping.items(): + for sd_part, diffusers_part in unet_conversion_map_layer: + v = v.replace(sd_part, diffusers_part) + mapping[k] = v + for k, v in mapping.items(): + if "resnets" in v: + for sd_part, diffusers_part in unet_conversion_map_resnet: + v = v.replace(sd_part, diffusers_part) + mapping[k] = v + new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()} + return new_state_dict + + # ================# # VAE Conversion # # ================# @@ -784,14 +913,24 @@ def convert_vae_state_dict(vae_state_dict): sd_mid_res_prefix = f"mid.block_{i+1}." vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) - vae_conversion_map_attn = [ - # (stable-diffusion, HF Diffusers) - ("norm.", "group_norm."), - ("q.", "query."), - ("k.", "key."), - ("v.", "value."), - ("proj_out.", "proj_attn."), - ] + if diffusers.__version__ < "0.17.0": + vae_conversion_map_attn = [ + # (stable-diffusion, HF Diffusers) + ("norm.", "group_norm."), + ("q.", "query."), + ("k.", "key."), + ("v.", "value."), + ("proj_out.", "proj_attn."), + ] + else: + vae_conversion_map_attn = [ + # (stable-diffusion, HF Diffusers) + ("norm.", "group_norm."), + ("q.", "to_q."), + ("k.", "to_k."), + ("v.", "to_v."), + ("proj_out.", "to_out.0."), + ] mapping = {k: k for k in vae_state_dict.keys()} for k, v in mapping.items(): @@ -808,7 +947,7 @@ def convert_vae_state_dict(vae_state_dict): for k, v in new_state_dict.items(): for weight_name in weights_to_convert: if f"mid.attn_1.{weight_name}.weight" in k: - # print(f"Reshaping {k} for SD format") + # print(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1") new_state_dict[k] = reshape_weight_for_sd(v) return new_state_dict @@ -857,7 +996,7 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"): # TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認 -def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None, unet_use_linear_projection_in_v2=False): +def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None, unet_use_linear_projection_in_v2=True): _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device) # Convert the UNet2DConditionModel model. 
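For orientation, the `convert_controlnet_state_dict_to_sd` / `convert_controlnet_state_dict_to_diffusers` helpers added to `library/model_util.py` above are pure key renamers built from `controlnet_conversion_map()`: they translate between the Diffusers ControlNet layout (`controlnet_down_blocks.*`, `controlnet_cond_embedding.*`, `down_blocks.*`, ...) and the Stable Diffusion ControlNet layout (`zero_convs.*`, `input_hint_block.*`, `input_blocks.*`, ...). A minimal usage sketch follows; the model id and the output file name are illustrative and not part of this patch.

```python
# Sketch: export a Diffusers-format ControlNet state dict to the Stable Diffusion key
# layout and convert it back, using the helpers defined above. The model id and the
# output file name are hypothetical examples.
from diffusers import ControlNetModel
from safetensors.torch import save_file

from library.model_util import (
    convert_controlnet_state_dict_to_diffusers,
    convert_controlnet_state_dict_to_sd,
)

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
diffusers_sd = controlnet.state_dict()                      # keys like "controlnet_down_blocks.0.weight"

sd_format = convert_controlnet_state_dict_to_sd(diffusers_sd)      # keys like "zero_convs.0.0.weight"
save_file(sd_format, "controlnet_canny_sd.safetensors")

roundtrip = convert_controlnet_state_dict_to_diffusers(sd_format)  # back to the Diffusers names
assert set(roundtrip.keys()) == set(diffusers_sd.keys())
```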
@@ -905,16 +1044,49 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt else: converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict) - logging.set_verbosity_error() # don't show annoying warning - text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device) - logging.set_verbosity_warning() - + # logging.set_verbosity_error() # don't show annoying warning + # text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device) + # logging.set_verbosity_warning() + # print(f"config: {text_model.config}") + cfg = CLIPTextConfig( + vocab_size=49408, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + max_position_embeddings=77, + hidden_act="quick_gelu", + layer_norm_eps=1e-05, + dropout=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + model_type="clip_text_model", + projection_dim=768, + torch_dtype="float32", + ) + text_model = CLIPTextModel._from_config(cfg) info = text_model.load_state_dict(converted_text_encoder_checkpoint) print("loading text encoder:", info) return text_model, vae, unet +def get_model_version_str_for_sd1_sd2(v2, v_parameterization): + # only for reference + version_str = "sd" + if v2: + version_str += "_v2" + else: + version_str += "_v1" + if v_parameterization: + version_str += "_v" + return version_str + + def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False): def convert_key(key): # position_idsの除去 @@ -986,7 +1158,9 @@ def convert_key(key): return new_sd -def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None): +def save_stable_diffusion_checkpoint( + v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, metadata, save_dtype=None, vae=None +): if ckpt_path is not None: # epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path) @@ -1048,7 +1222,7 @@ def update_sd(prefix, sd): if is_safetensors(output_file): # TODO Tensor以外のdictの値を削除したほうがいいか - save_file(state_dict, output_file) + save_file(state_dict, output_file, metadata) else: torch.save(new_ckpt, output_file) diff --git a/library/original_unet.py b/library/original_unet.py new file mode 100644 index 000000000..c0028ddc2 --- /dev/null +++ b/library/original_unet.py @@ -0,0 +1,1606 @@ +# Diffusers 0.10.2からStable Diffusionに必要な部分だけを持ってくる +# 条件分岐等で不要な部分は削除している +# コードの多くはDiffusersからコピーしている +# 制約として、モデルのstate_dictがDiffusers 0.10.2のものと同じ形式である必要がある + +# Copy from Diffusers 0.10.2 for Stable Diffusion. Most of the code is copied from Diffusers. +# Unnecessary parts are deleted by condition branching. 
+# As a constraint, the state_dict of the model must be in the same format as that of Diffusers 0.10.2 + +""" +v1.5とv2.1の相違点は +- attention_head_dimがintかlist[int]か +- cross_attention_dimが768か1024か +- use_linear_projection: trueがない(=False, 1.5)かあるか +- upcast_attentionがFalse(1.5)かTrue(2.1)か +- (以下は多分無視していい) +- sample_sizeが64か96か +- dual_cross_attentionがあるかないか +- num_class_embedsがあるかないか +- only_cross_attentionがあるかないか + +v1.5 +{ + "_class_name": "UNet2DConditionModel", + "_diffusers_version": "0.6.0", + "act_fn": "silu", + "attention_head_dim": 8, + "block_out_channels": [ + 320, + 640, + 1280, + 1280 + ], + "center_input_sample": false, + "cross_attention_dim": 768, + "down_block_types": [ + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D" + ], + "downsample_padding": 1, + "flip_sin_to_cos": true, + "freq_shift": 0, + "in_channels": 4, + "layers_per_block": 2, + "mid_block_scale_factor": 1, + "norm_eps": 1e-05, + "norm_num_groups": 32, + "out_channels": 4, + "sample_size": 64, + "up_block_types": [ + "UpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D" + ] +} + +v2.1 +{ + "_class_name": "UNet2DConditionModel", + "_diffusers_version": "0.10.0.dev0", + "act_fn": "silu", + "attention_head_dim": [ + 5, + 10, + 20, + 20 + ], + "block_out_channels": [ + 320, + 640, + 1280, + 1280 + ], + "center_input_sample": false, + "cross_attention_dim": 1024, + "down_block_types": [ + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D" + ], + "downsample_padding": 1, + "dual_cross_attention": false, + "flip_sin_to_cos": true, + "freq_shift": 0, + "in_channels": 4, + "layers_per_block": 2, + "mid_block_scale_factor": 1, + "norm_eps": 1e-05, + "norm_num_groups": 32, + "num_class_embeds": null, + "only_cross_attention": false, + "out_channels": 4, + "sample_size": 96, + "up_block_types": [ + "UpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D" + ], + "use_linear_projection": true, + "upcast_attention": true +} +""" + +import math +from types import SimpleNamespace +from typing import Dict, Optional, Tuple, Union +import torch +from torch import nn +from torch.nn import functional as F +from einops import rearrange + +BLOCK_OUT_CHANNELS: Tuple[int] = (320, 640, 1280, 1280) +TIMESTEP_INPUT_DIM = BLOCK_OUT_CHANNELS[0] +TIME_EMBED_DIM = BLOCK_OUT_CHANNELS[0] * 4 +IN_CHANNELS: int = 4 +OUT_CHANNELS: int = 4 +LAYERS_PER_BLOCK: int = 2 +LAYERS_PER_BLOCK_UP: int = LAYERS_PER_BLOCK + 1 +TIME_EMBED_FLIP_SIN_TO_COS: bool = True +TIME_EMBED_FREQ_SHIFT: int = 0 +NORM_GROUPS: int = 32 +NORM_EPS: float = 1e-5 +TRANSFORMER_NORM_NUM_GROUPS = 32 + +DOWN_BLOCK_TYPES = ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"] +UP_BLOCK_TYPES = ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"] + + +# region memory effcient attention + +# FlashAttentionを使うCrossAttention +# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py +# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE + +# constants + +EPSILON = 1e-6 + +# helper functions + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +# flash attention forwards and backwards + +# https://arxiv.org/abs/2205.14135 + + +class FlashAttentionFunction(torch.autograd.Function): + @staticmethod + 
@torch.no_grad() + def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): + """Algorithm 2 in the paper""" + + device = q.device + dtype = q.dtype + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + + o = torch.zeros_like(q) + all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) + all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device) + + scale = q.shape[-1] ** -0.5 + + if not exists(mask): + mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) + else: + mask = rearrange(mask, "b n -> b 1 1 n") + mask = mask.split(q_bucket_size, dim=-1) + + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + mask, + all_row_sums.split(q_bucket_size, dim=-2), + all_row_maxes.split(q_bucket_size, dim=-2), + ) + + for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff + + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + ) + + for k_ind, (kc, vc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size + + attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale + + if exists(row_mask): + attn_weights.masked_fill_(~row_mask, max_neg_value) + + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( + q_start_index - k_start_index + 1 + ) + attn_weights.masked_fill_(causal_mask, max_neg_value) + + block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) + attn_weights -= block_row_maxes + exp_weights = torch.exp(attn_weights) + + if exists(row_mask): + exp_weights.masked_fill_(~row_mask, 0.0) + + block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON) + + new_row_maxes = torch.maximum(block_row_maxes, row_maxes) + + exp_values = torch.einsum("... i j, ... j d -> ... 
i d", exp_weights, vc) + + exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) + exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) + + new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums + + oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values) + + row_maxes.copy_(new_row_maxes) + row_sums.copy_(new_row_sums) + + ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) + ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) + + return o + + @staticmethod + @torch.no_grad() + def backward(ctx, do): + """Algorithm 4 in the paper""" + + causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args + q, k, v, o, l, m = ctx.saved_tensors + + device = q.device + + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + + dq = torch.zeros_like(q) + dk = torch.zeros_like(k) + dv = torch.zeros_like(v) + + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + do.split(q_bucket_size, dim=-2), + mask, + l.split(q_bucket_size, dim=-2), + m.split(q_bucket_size, dim=-2), + dq.split(q_bucket_size, dim=-2), + ) + + for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff + + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + dk.split(k_bucket_size, dim=-2), + dv.split(k_bucket_size, dim=-2), + ) + + for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size + + attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale + + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( + q_start_index - k_start_index + 1 + ) + attn_weights.masked_fill_(causal_mask, max_neg_value) + + exp_attn_weights = torch.exp(attn_weights - mc) + + if exists(row_mask): + exp_attn_weights.masked_fill_(~row_mask, 0.0) + + p = exp_attn_weights / lc + + dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc) + dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc) + + D = (doc * oc).sum(dim=-1, keepdims=True) + ds = p * scale * (dp - D) + + dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc) + dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc) + + dqc.add_(dq_chunk) + dkc.add_(dk_chunk) + dvc.add_(dv_chunk) + + return dq, dk, dv, None, None, None, None + + +# endregion + + +def get_parameter_dtype(parameter: torch.nn.Module): + return next(parameter.parameters()).dtype + + +def get_parameter_device(parameter: torch.nn.Module): + return next(parameter.parameters()).device + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the + embeddings. :return: an [N x dim] Tensor of positional embeddings. 
+ """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class SampleOutput: + def __init__(self, sample): + self.sample = sample + + +class TimestepEmbedding(nn.Module): + def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None): + super().__init__() + + self.linear_1 = nn.Linear(in_channels, time_embed_dim) + self.act = None + if act_fn == "silu": + self.act = nn.SiLU() + elif act_fn == "mish": + self.act = nn.Mish() + + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out) + + def forward(self, sample): + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + return sample + + +class Timesteps(nn.Module): + def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + + def forward(self, timesteps): + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + ) + return t_emb + + +class ResnetBlock2D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + self.norm1 = torch.nn.GroupNorm(num_groups=NORM_GROUPS, num_channels=in_channels, eps=NORM_EPS, affine=True) + + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + self.time_emb_proj = torch.nn.Linear(TIME_EMBED_DIM, out_channels) + + self.norm2 = torch.nn.GroupNorm(num_groups=NORM_GROUPS, num_channels=out_channels, eps=NORM_EPS, affine=True) + self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + + # if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + + self.use_in_shortcut = self.in_channels != self.out_channels + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, input_tensor, temb): + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.conv1(hidden_states) + + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = input_tensor 
+ hidden_states + + return output_tensor + + +class DownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + add_downsample=True, + ): + super().__init__() + + self.has_cross_attention = False + resnets = [] + + for i in range(LAYERS_PER_BLOCK): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + ) + ) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = [Downsample2D(out_channels, out_channels=out_channels)] + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def set_use_memory_efficient_attention(self, xformers, mem_eff): + pass + + def set_use_sdpa(self, sdpa): + pass + + def forward(self, hidden_states, temb=None): + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class Downsample2D(nn.Module): + def __init__(self, channels, out_channels): + super().__init__() + + self.channels = channels + self.out_channels = out_channels + + self.conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=2, padding=1) + + def forward(self, hidden_states): + assert hidden_states.shape[1] == self.channels + hidden_states = self.conv(hidden_states) + + return hidden_states + + +class CrossAttention(nn.Module): + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + upcast_attention: bool = False, + ): + super().__init__() + inner_dim = dim_head * heads + cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + + self.scale = dim_head**-0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(inner_dim, query_dim)) + # no dropout here + + self.use_memory_efficient_attention_xformers = False + self.use_memory_efficient_attention_mem_eff = False + self.use_sdpa = False + + def set_use_memory_efficient_attention(self, xformers, mem_eff): + self.use_memory_efficient_attention_xformers = xformers + self.use_memory_efficient_attention_mem_eff = mem_eff + + def set_use_sdpa(self, sdpa): + self.use_sdpa = sdpa + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, 
dim * head_size) + return tensor + + def forward(self, hidden_states, context=None, mask=None): + if self.use_memory_efficient_attention_xformers: + return self.forward_memory_efficient_xformers(hidden_states, context, mask) + if self.use_memory_efficient_attention_mem_eff: + return self.forward_memory_efficient_mem_eff(hidden_states, context, mask) + if self.use_sdpa: + return self.forward_sdpa(hidden_states, context, mask) + + query = self.to_q(hidden_states) + context = context if context is not None else hidden_states + key = self.to_k(context) + value = self.to_v(context) + + query = self.reshape_heads_to_batch_dim(query) + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + hidden_states = self._attention(query, key, value) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + # hidden_states = self.to_out[1](hidden_states) # no dropout + return hidden_states + + def _attention(self, query, key, value): + if self.upcast_attention: + query = query.float() + key = key.float() + + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + attention_probs = attention_scores.softmax(dim=-1) + + # cast back to the original dtype + attention_probs = attention_probs.to(value.dtype) + + # compute attention output + hidden_states = torch.bmm(attention_probs, value) + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + # TODO support Hypernetworks + def forward_memory_efficient_xformers(self, x, context=None, mask=None): + import xformers.ops + + h = self.heads + q_in = self.to_q(x) + context = context if context is not None else x + context = context.to(x.dtype) + k_in = self.to_k(context) + v_in = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in + + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる + + out = rearrange(out, "b n h d -> b n (h d)", h=h) + + out = self.to_out[0](out) + return out + + def forward_memory_efficient_mem_eff(self, x, context=None, mask=None): + flash_func = FlashAttentionFunction + + q_bucket_size = 512 + k_bucket_size = 1024 + + h = self.heads + q = self.to_q(x) + context = context if context is not None else x + context = context.to(x.dtype) + k = self.to_k(context) + v = self.to_v(context) + del context, x + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) + + out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size) + + out = rearrange(out, "b h n d -> b n (h d)") + + out = self.to_out[0](out) + return out + + def forward_sdpa(self, x, context=None, mask=None): + h = self.heads + q_in = self.to_q(x) + context = context if context is not None else x + context = context.to(x.dtype) + k_in = self.to_k(context) + v_in = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in + + out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) + + out = rearrange(out, "b h n d -> b n (h d)", h=h) + + out = self.to_out[0](out) + return out + + +# feedforward +class GEGLU(nn.Module): + r""" + A variant of the gated linear unit activation function from 
https://arxiv.org/abs/2002.05202. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + """ + + def __init__(self, dim_in: int, dim_out: int): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def gelu(self, gate): + if gate.device.type != "mps": + return F.gelu(gate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + + def forward(self, hidden_states): + hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) + return hidden_states * self.gelu(gate) + + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + ): + super().__init__() + inner_dim = int(dim * 4) # mult is always 4 + + self.net = nn.ModuleList([]) + # project in + self.net.append(GEGLU(dim, inner_dim)) + # project dropout + self.net.append(nn.Identity()) # nn.Dropout(0)) # dummy for dropout with 0 + # project out + self.net.append(nn.Linear(inner_dim, dim)) + + def forward(self, hidden_states): + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, dim: int, num_attention_heads: int, attention_head_dim: int, cross_attention_dim: int, upcast_attention: bool = False + ): + super().__init__() + + # 1. Self-Attn + self.attn1 = CrossAttention( + query_dim=dim, + cross_attention_dim=None, + heads=num_attention_heads, + dim_head=attention_head_dim, + upcast_attention=upcast_attention, + ) + self.ff = FeedForward(dim) + + # 2. Cross-Attn + self.attn2 = CrossAttention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + upcast_attention=upcast_attention, + ) + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(dim) + + def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool): + self.attn1.set_use_memory_efficient_attention(xformers, mem_eff) + self.attn2.set_use_memory_efficient_attention(xformers, mem_eff) + + def set_use_sdpa(self, sdpa: bool): + self.attn1.set_use_sdpa(sdpa) + self.attn2.set_use_sdpa(sdpa) + + def forward(self, hidden_states, context=None, timestep=None): + # 1. Self-Attention + norm_hidden_states = self.norm1(hidden_states) + + hidden_states = self.attn1(norm_hidden_states) + hidden_states + + # 2. Cross-Attention + norm_hidden_states = self.norm2(hidden_states) + hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states + + # 3. 
Feed-forward + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + + return hidden_states + + +class Transformer2DModel(nn.Module): + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + use_linear_projection: bool = False, + upcast_attention: bool = False, + ): + super().__init__() + self.in_channels = in_channels + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + self.use_linear_projection = use_linear_projection + + self.norm = torch.nn.GroupNorm(num_groups=TRANSFORMER_NORM_NUM_GROUPS, num_channels=in_channels, eps=1e-6, affine=True) + + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) + else: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + ) + ] + ) + + if use_linear_projection: + self.proj_out = nn.Linear(in_channels, inner_dim) + else: + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + + def set_use_memory_efficient_attention(self, xformers, mem_eff): + for transformer in self.transformer_blocks: + transformer.set_use_memory_efficient_attention(xformers, mem_eff) + + def set_use_sdpa(self, sdpa): + for transformer in self.transformer_blocks: + transformer.set_use_sdpa(sdpa) + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): + # 1. Input + batch, _, height, weight = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + hidden_states = self.proj_in(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep) + + # 3. 
Output + if not self.use_linear_projection: + hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + + if not return_dict: + return (output,) + + return SampleOutput(sample=output) + + +class CrossAttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + add_downsample=True, + cross_attention_dim=1280, + attn_num_head_channels=1, + use_linear_projection=False, + upcast_attention=False, + ): + super().__init__() + self.has_cross_attention = True + resnets = [] + attentions = [] + + self.attn_num_head_channels = attn_num_head_channels + + for i in range(LAYERS_PER_BLOCK): + in_channels = in_channels if i == 0 else out_channels + + resnets.append(ResnetBlock2D(in_channels=in_channels, out_channels=out_channels)) + attentions.append( + Transformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + cross_attention_dim=cross_attention_dim, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList([Downsample2D(out_channels, out_channels)]) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def set_use_memory_efficient_attention(self, xformers, mem_eff): + for attn in self.attentions: + attn.set_use_memory_efficient_attention(xformers, mem_eff) + + def set_use_sdpa(self, sdpa): + for attn in self.attentions: + attn.set_use_sdpa(sdpa) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class UNetMidBlock2DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + attn_num_head_channels=1, + cross_attention_dim=1280, + use_linear_projection=False, + ): + super().__init__() + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + # Middle block has two resnets and one attention + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + ), + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + ), + ] + attentions = [ + Transformer2DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + 
in_channels=in_channels, + cross_attention_dim=cross_attention_dim, + use_linear_projection=use_linear_projection, + ) + ] + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def set_use_memory_efficient_attention(self, xformers, mem_eff): + for attn in self.attentions: + attn.set_use_memory_efficient_attention(xformers, mem_eff) + + def set_use_sdpa(self, sdpa): + for attn in self.attentions: + attn.set_use_sdpa(sdpa) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + for i, resnet in enumerate(self.resnets): + attn = None if i == 0 else self.attentions[i - 1] + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + if attn is not None: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states + )[0] + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + if attn is not None: + hidden_states = attn(hidden_states, encoder_hidden_states).sample + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class Upsample2D(nn.Module): + def __init__(self, channels, out_channels): + super().__init__() + self.channels = channels + self.out_channels = out_channels + self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1) + + def forward(self, hidden_states, output_size): + assert hidden_states.shape[1] == self.channels + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch + # https://github.com/pytorch/pytorch/issues/86679 + dtype = hidden_states.dtype + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.float32) + + # upsample_nearest_nhwc fails with large batch sizes. 
see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + + # if `output_size` is passed we force the interpolation output size and do not make use of `scale_factor=2` + if output_size is None: + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + else: + hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(dtype) + + hidden_states = self.conv(hidden_states) + + return hidden_states + + +class UpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + add_upsample=True, + ): + super().__init__() + + self.has_cross_attention = False + resnets = [] + + for i in range(LAYERS_PER_BLOCK_UP): + res_skip_channels = in_channels if (i == LAYERS_PER_BLOCK_UP - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def set_use_memory_efficient_attention(self, xformers, mem_eff): + pass + + def set_use_sdpa(self, sdpa): + pass + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class CrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + attn_num_head_channels=1, + cross_attention_dim=1280, + add_upsample=True, + use_linear_projection=False, + upcast_attention=False, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(LAYERS_PER_BLOCK_UP): + res_skip_channels = in_channels if (i == LAYERS_PER_BLOCK_UP - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + ) + ) + attentions.append( + Transformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + cross_attention_dim=cross_attention_dim, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)]) + else: + 
self.upsamplers = None + + self.gradient_checkpointing = False + + def set_use_memory_efficient_attention(self, xformers, mem_eff): + for attn in self.attentions: + attn.set_use_memory_efficient_attention(xformers, mem_eff) + + def set_use_sdpa(self, spda): + for attn in self.attentions: + attn.set_use_sdpa(spda) + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + ): + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +def get_down_block( + down_block_type, + in_channels, + out_channels, + add_downsample, + attn_num_head_channels, + cross_attention_dim, + use_linear_projection, + upcast_attention, +): + if down_block_type == "DownBlock2D": + return DownBlock2D( + in_channels=in_channels, + out_channels=out_channels, + add_downsample=add_downsample, + ) + elif down_block_type == "CrossAttnDownBlock2D": + return CrossAttnDownBlock2D( + in_channels=in_channels, + out_channels=out_channels, + add_downsample=add_downsample, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + + +def get_up_block( + up_block_type, + in_channels, + out_channels, + prev_output_channel, + add_upsample, + attn_num_head_channels, + cross_attention_dim=None, + use_linear_projection=False, + upcast_attention=False, +): + if up_block_type == "UpBlock2D": + return UpBlock2D( + in_channels=in_channels, + prev_output_channel=prev_output_channel, + out_channels=out_channels, + add_upsample=add_upsample, + ) + elif up_block_type == "CrossAttnUpBlock2D": + return CrossAttnUpBlock2D( + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + attn_num_head_channels=attn_num_head_channels, + cross_attention_dim=cross_attention_dim, + add_upsample=add_upsample, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + + +class UNet2DConditionModel(nn.Module): + _supports_gradient_checkpointing = True + + def __init__( + self, + sample_size: Optional[int] = None, + attention_head_dim: Union[int, Tuple[int]] = 8, + cross_attention_dim: int = 1280, + use_linear_projection: bool = False, + upcast_attention: bool = False, + **kwargs, + ): + super().__init__() + assert sample_size is not None, "sample_size must be specified" + print( + f"UNet2DConditionModel: {sample_size}, {attention_head_dim}, {cross_attention_dim}, 
{use_linear_projection}, {upcast_attention}" + ) + + # 外部からの参照用に定義しておく + self.in_channels = IN_CHANNELS + self.out_channels = OUT_CHANNELS + + self.sample_size = sample_size + self.prepare_config() + + # state_dictの書式が変わるのでmoduleの持ち方は変えられない + + # input + self.conv_in = nn.Conv2d(IN_CHANNELS, BLOCK_OUT_CHANNELS[0], kernel_size=3, padding=(1, 1)) + + # time + self.time_proj = Timesteps(BLOCK_OUT_CHANNELS[0], TIME_EMBED_FLIP_SIN_TO_COS, TIME_EMBED_FREQ_SHIFT) + + self.time_embedding = TimestepEmbedding(TIMESTEP_INPUT_DIM, TIME_EMBED_DIM) + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * 4 + + # down + output_channel = BLOCK_OUT_CHANNELS[0] + for i, down_block_type in enumerate(DOWN_BLOCK_TYPES): + input_channel = output_channel + output_channel = BLOCK_OUT_CHANNELS[i] + is_final_block = i == len(BLOCK_OUT_CHANNELS) - 1 + + down_block = get_down_block( + down_block_type, + in_channels=input_channel, + out_channels=output_channel, + add_downsample=not is_final_block, + attn_num_head_channels=attention_head_dim[i], + cross_attention_dim=cross_attention_dim, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock2DCrossAttn( + in_channels=BLOCK_OUT_CHANNELS[-1], + attn_num_head_channels=attention_head_dim[-1], + cross_attention_dim=cross_attention_dim, + use_linear_projection=use_linear_projection, + ) + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(BLOCK_OUT_CHANNELS)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(UP_BLOCK_TYPES): + is_final_block = i == len(BLOCK_OUT_CHANNELS) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(BLOCK_OUT_CHANNELS) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + add_upsample=add_upsample, + attn_num_head_channels=reversed_attention_head_dim[i], + cross_attention_dim=cross_attention_dim, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=BLOCK_OUT_CHANNELS[0], num_groups=NORM_GROUPS, eps=NORM_EPS) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1) + + # region diffusers compatibility + def prepare_config(self): + self.config = SimpleNamespace() + + @property + def dtype(self) -> torch.dtype: + # `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). + return get_parameter_dtype(self) + + @property + def device(self) -> torch.device: + # `torch.device`: The device on which the module is (assuming that all the module parameters are on the same device). 
+ return get_parameter_device(self) + + def set_attention_slice(self, slice_size): + raise NotImplementedError("Attention slicing is not supported for this model.") + + def is_gradient_checkpointing(self) -> bool: + return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules()) + + def enable_gradient_checkpointing(self): + self.set_gradient_checkpointing(value=True) + + def disable_gradient_checkpointing(self): + self.set_gradient_checkpointing(value=False) + + def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool) -> None: + modules = self.down_blocks + [self.mid_block] + self.up_blocks + for module in modules: + module.set_use_memory_efficient_attention(xformers, mem_eff) + + def set_use_sdpa(self, sdpa: bool) -> None: + modules = self.down_blocks + [self.mid_block] + self.up_blocks + for module in modules: + module.set_use_sdpa(sdpa) + + def set_gradient_checkpointing(self, value=False): + modules = self.down_blocks + [self.mid_block] + self.up_blocks + for module in modules: + print(module.__class__.__name__, module.gradient_checkpointing, "->", value) + module.gradient_checkpointing = value + + # endregion + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + return_dict: bool = True, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + ) -> Union[Dict, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a dict instead of a plain tuple. + + Returns: + `SampleOutput` or `tuple`: + `SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + # デフォルトではサンプルは「2^アップサンプルの数」、つまり64の倍数である必要がある + # ただそれ以外のサイズにも対応できるように、必要ならアップサンプルのサイズを変更する + # 多分画質が悪くなるので、64で割り切れるようにしておくのが良い + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + # 64で割り切れないときはupsamplerにサイズを伝える + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + # logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # 1. time + timesteps = timestep + timesteps = self.handle_unusual_timesteps(sample, timesteps) # 変な時だけ処理 + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + # timestepsは重みを含まないので常にfloat32のテンソルを返す + # しかしtime_embeddingはfp16で動いているかもしれないので、ここでキャストする必要がある + # time_projでキャストしておけばいいんじゃね? 
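As the comments above note, `time_proj` has no weights and therefore always returns float32, while `time_embedding` may be running in fp16 or bf16, so `t_emb` is cast to the module dtype first. A minimal sketch of the same pattern, using a plain `nn.Linear` as a stand-in for `TimestepEmbedding` and purely illustrative shapes:

```python
# Sketch only, not part of the patch: why the cast is needed.
import torch
import torch.nn as nn

t_emb = torch.randn(2, 320)                                # float32, like time_proj's output
time_embedding = nn.Linear(320, 1280).to(torch.bfloat16)   # stand-in for TimestepEmbedding in bf16

emb = time_embedding(t_emb.to(time_embedding.weight.dtype))  # cast first, then embed
print(emb.dtype)  # torch.bfloat16
```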
+ t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) + + # 2. pre-process + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + # downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、 + # まあこちらのほうがわかりやすいかもしれない + if downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # skip connectionにControlNetの出力を追加する + if down_block_additional_residuals is not None: + down_block_res_samples = list(down_block_res_samples) + for i in range(len(down_block_res_samples)): + down_block_res_samples[i] += down_block_additional_residuals[i] + down_block_res_samples = tuple(down_block_res_samples) + + # 4. mid + sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states) + + # ControlNetの出力を追加する + if mid_block_additional_residual is not None: + sample += mid_block_additional_residual + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection + + # if we have not reached the final block and need to forward the upsample size, we do it here + # 前述のように最後のブロック以外ではupsample_sizeを伝える + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + + # 6. post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + + return SampleOutput(sample=sample) + + def handle_unusual_timesteps(self, sample, timesteps): + r""" + timestampsがTensorでない場合、Tensorに変換する。またOnnx/Core MLと互換性のあるようにbatchサイズまでbroadcastする。 + """ + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timesteps, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + return timesteps diff --git a/library/sai_model_spec.py b/library/sai_model_spec.py new file mode 100644 index 000000000..472686ba4 --- /dev/null +++ b/library/sai_model_spec.py @@ -0,0 +1,305 @@ +# based on https://github.com/Stability-AI/ModelSpec +import datetime +import hashlib +from io import BytesIO +import os +from typing import List, Optional, Tuple, Union +import safetensors + +r""" +# Metadata Example +metadata = { + # === Must === + "modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec + "modelspec.architecture": "stable-diffusion-xl-v1-base", # Architecture, reference the ID of the original model of the arch to match the ID + "modelspec.implementation": "sgm", + "modelspec.title": "Example Model Version 1.0", # Clean, human-readable title. May use your own phrasing/language/etc + # === Should === + "modelspec.author": "Example Corp", # Your name or company name + "modelspec.description": "This is my example model to show you how to do it!", # Describe the model in your own words/language/etc. Focus on what users need to know + "modelspec.date": "2023-07-20", # ISO-8601 compliant date of when the model was created + # === Can === + "modelspec.license": "ExampleLicense-1.0", # eg CreativeML Open RAIL, etc. 
+ "modelspec.usage_hint": "Use keyword 'example'" # In your own language, very short hints about how the user should use the model +} +""" + +BASE_METADATA = { + # === Must === + "modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec + "modelspec.architecture": None, + "modelspec.implementation": None, + "modelspec.title": None, + "modelspec.resolution": None, + # === Should === + "modelspec.description": None, + "modelspec.author": None, + "modelspec.date": None, + # === Can === + "modelspec.license": None, + "modelspec.tags": None, + "modelspec.merged_from": None, + "modelspec.prediction_type": None, + "modelspec.timestep_range": None, + "modelspec.encoder_layer": None, +} + +# 別に使うやつだけ定義 +MODELSPEC_TITLE = "modelspec.title" + +ARCH_SD_V1 = "stable-diffusion-v1" +ARCH_SD_V2_512 = "stable-diffusion-v2-512" +ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v" +ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base" + +ADAPTER_LORA = "lora" +ADAPTER_TEXTUAL_INVERSION = "textual-inversion" + +IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models" +IMPL_DIFFUSERS = "diffusers" + +PRED_TYPE_EPSILON = "epsilon" +PRED_TYPE_V = "v" + + +def load_bytes_in_safetensors(tensors): + bytes = safetensors.torch.save(tensors) + b = BytesIO(bytes) + + b.seek(0) + header = b.read(8) + n = int.from_bytes(header, "little") + + offset = n + 8 + b.seek(offset) + + return b.read() + + +def precalculate_safetensors_hashes(state_dict): + # calculate each tensor one by one to reduce memory usage + hash_sha256 = hashlib.sha256() + for tensor in state_dict.values(): + single_tensor_sd = {"tensor": tensor} + bytes_for_tensor = load_bytes_in_safetensors(single_tensor_sd) + hash_sha256.update(bytes_for_tensor) + + return f"0x{hash_sha256.hexdigest()}" + + +def update_hash_sha256(metadata: dict, state_dict: dict): + raise NotImplementedError + + +def build_metadata( + state_dict: Optional[dict], + v2: bool, + v_parameterization: bool, + sdxl: bool, + lora: bool, + textual_inversion: bool, + timestamp: float, + title: Optional[str] = None, + reso: Optional[Union[int, Tuple[int, int]]] = None, + is_stable_diffusion_ckpt: Optional[bool] = None, + author: Optional[str] = None, + description: Optional[str] = None, + license: Optional[str] = None, + tags: Optional[str] = None, + merged_from: Optional[str] = None, + timesteps: Optional[Tuple[int, int]] = None, + clip_skip: Optional[int] = None, +): + # if state_dict is None, hash is not calculated + + metadata = {} + metadata.update(BASE_METADATA) + + # TODO メモリを消費せずかつ正しいハッシュ計算の方法がわかったら実装する + # if state_dict is not None: + # hash = precalculate_safetensors_hashes(state_dict) + # metadata["modelspec.hash_sha256"] = hash + + if sdxl: + arch = ARCH_SD_XL_V1_BASE + elif v2: + if v_parameterization: + arch = ARCH_SD_V2_768_V + else: + arch = ARCH_SD_V2_512 + else: + arch = ARCH_SD_V1 + + if lora: + arch += f"/{ADAPTER_LORA}" + elif textual_inversion: + arch += f"/{ADAPTER_TEXTUAL_INVERSION}" + + metadata["modelspec.architecture"] = arch + + if not lora and not textual_inversion and is_stable_diffusion_ckpt is None: + is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion + + if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt: + # Stable Diffusion ckpt, TI, SDXL LoRA + impl = IMPL_STABILITY_AI + else: + # v1/v2 LoRA or Diffusers + impl = IMPL_DIFFUSERS + metadata["modelspec.implementation"] = impl + + if title is None: + if lora: + title = "LoRA" + elif textual_inversion: + title = 
"TextualInversion" + else: + title = "Checkpoint" + title += f"@{timestamp}" + metadata[MODELSPEC_TITLE] = title + + if author is not None: + metadata["modelspec.author"] = author + else: + del metadata["modelspec.author"] + + if description is not None: + metadata["modelspec.description"] = description + else: + del metadata["modelspec.description"] + + if merged_from is not None: + metadata["modelspec.merged_from"] = merged_from + else: + del metadata["modelspec.merged_from"] + + if license is not None: + metadata["modelspec.license"] = license + else: + del metadata["modelspec.license"] + + if tags is not None: + metadata["modelspec.tags"] = tags + else: + del metadata["modelspec.tags"] + + # remove microsecond from time + int_ts = int(timestamp) + + # time to iso-8601 compliant date + date = datetime.datetime.fromtimestamp(int_ts).isoformat() + metadata["modelspec.date"] = date + + if reso is not None: + # comma separated to tuple + if isinstance(reso, str): + reso = tuple(map(int, reso.split(","))) + if len(reso) == 1: + reso = (reso[0], reso[0]) + else: + # resolution is defined in dataset, so use default + if sdxl: + reso = 1024 + elif v2 and v_parameterization: + reso = 768 + else: + reso = 512 + if isinstance(reso, int): + reso = (reso, reso) + + metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}" + + if v_parameterization: + metadata["modelspec.prediction_type"] = PRED_TYPE_V + else: + metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON + + if timesteps is not None: + if isinstance(timesteps, str) or isinstance(timesteps, int): + timesteps = (timesteps, timesteps) + if len(timesteps) == 1: + timesteps = (timesteps[0], timesteps[0]) + metadata["modelspec.timestep_range"] = f"{timesteps[0]},{timesteps[1]}" + else: + del metadata["modelspec.timestep_range"] + + if clip_skip is not None: + metadata["modelspec.encoder_layer"] = f"{clip_skip}" + else: + del metadata["modelspec.encoder_layer"] + + # # assert all values are filled + # assert all([v is not None for v in metadata.values()]), metadata + if not all([v is not None for v in metadata.values()]): + print(f"Internal error: some metadata values are None: {metadata}") + + return metadata + + +# region utils + + +def get_title(metadata: dict) -> Optional[str]: + return metadata.get(MODELSPEC_TITLE, None) + + +def load_metadata_from_safetensors(model: str) -> dict: + if not model.endswith(".safetensors"): + return {} + + with safetensors.safe_open(model, framework="pt") as f: + metadata = f.metadata() + if metadata is None: + metadata = {} + return metadata + + +def build_merged_from(models: List[str]) -> str: + def get_title(model: str): + metadata = load_metadata_from_safetensors(model) + title = metadata.get(MODELSPEC_TITLE, None) + if title is None: + title = os.path.splitext(os.path.basename(model))[0] # use filename + return title + + titles = [get_title(model) for model in models] + return ", ".join(titles) + + +# endregion + + +r""" +if __name__ == "__main__": + import argparse + import torch + from safetensors.torch import load_file + from library import train_util + + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt", type=str, required=True) + args = parser.parse_args() + + print(f"Loading {args.ckpt}") + state_dict = load_file(args.ckpt) + + print(f"Calculating metadata") + metadata = get(state_dict, False, False, False, False, "sgm", False, False, "title", "date", 256, 1000, 0) + print(metadata) + del state_dict + + # by reference implementation + with open(args.ckpt, mode="rb") as file_data: 
+ file_hash = hashlib.sha256() + head_len = struct.unpack("Q", file_data.read(8)) # int64 header length prefix + header = json.loads(file_data.read(head_len[0])) # header itself, json string + content = ( + file_data.read() + ) # All other content is tightly packed tensors. Copy to RAM for simplicity, but you can avoid this read with a more careful FS-dependent impl. + file_hash.update(content) + # ===== Update the hash for modelspec ===== + by_ref = f"0x{file_hash.hexdigest()}" + print(by_ref) + print("is same?", by_ref == metadata["modelspec.hash_sha256"]) + +""" diff --git a/library/sdxl_lpw_stable_diffusion.py b/library/sdxl_lpw_stable_diffusion.py new file mode 100644 index 000000000..e03ee4056 --- /dev/null +++ b/library/sdxl_lpw_stable_diffusion.py @@ -0,0 +1,1342 @@ +# copy from https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py +# and modify to support SD2.x + +import inspect +import re +from typing import Callable, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from packaging import version +from tqdm import tqdm +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + +from diffusers import SchedulerMixin, StableDiffusionPipeline +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker +from diffusers.utils import logging +from PIL import Image + +from library import sdxl_model_util, sdxl_train_util, train_util + + +try: + from diffusers.utils import PIL_INTERPOLATION +except ImportError: + if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): + PIL_INTERPOLATION = { + "linear": PIL.Image.Resampling.BILINEAR, + "bilinear": PIL.Image.Resampling.BILINEAR, + "bicubic": PIL.Image.Resampling.BICUBIC, + "lanczos": PIL.Image.Resampling.LANCZOS, + "nearest": PIL.Image.Resampling.NEAREST, + } + else: + PIL_INTERPOLATION = { + "linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + "nearest": PIL.Image.NEAREST, + } +# ------------------------------------------------------------------------------ + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +re_attention = re.compile( + r""" +\\\(| +\\\)| +\\\[| +\\]| +\\\\| +\\| +\(| +\[| +:([+-]?[.\d]+)\)| +\)| +]| +[^\\()\[\]:]+| +: +""", + re.X, +) + + +def parse_prompt_attention(text): + """ + Parses a string with attention tokens and returns a list of pairs: text and its associated weight. 
+ Accepted tokens are: + (abc) - increases attention to abc by a multiplier of 1.1 + (abc:3.12) - increases attention to abc by a multiplier of 3.12 + [abc] - decreases attention to abc by a multiplier of 1.1 + \( - literal character '(' + \[ - literal character '[' + \) - literal character ')' + \] - literal character ']' + \\ - literal character '\' + anything else - just text + >>> parse_prompt_attention('normal text') + [['normal text', 1.0]] + >>> parse_prompt_attention('an (important) word') + [['an ', 1.0], ['important', 1.1], [' word', 1.0]] + >>> parse_prompt_attention('(unbalanced') + [['unbalanced', 1.1]] + >>> parse_prompt_attention('\(literal\]') + [['(literal]', 1.0]] + >>> parse_prompt_attention('(unnecessary)(parens)') + [['unnecessaryparens', 1.1]] + >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') + [['a ', 1.0], + ['house', 1.5730000000000004], + [' ', 1.1], + ['on', 1.0], + [' a ', 1.1], + ['hill', 0.55], + [', sun, ', 1.1], + ['sky', 1.4641000000000006], + ['.', 1.1]] + """ + + res = [] + round_brackets = [] + square_brackets = [] + + round_bracket_multiplier = 1.1 + square_bracket_multiplier = 1 / 1.1 + + def multiply_range(start_position, multiplier): + for p in range(start_position, len(res)): + res[p][1] *= multiplier + + for m in re_attention.finditer(text): + text = m.group(0) + weight = m.group(1) + + if text.startswith("\\"): + res.append([text[1:], 1.0]) + elif text == "(": + round_brackets.append(len(res)) + elif text == "[": + square_brackets.append(len(res)) + elif weight is not None and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), float(weight)) + elif text == ")" and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), round_bracket_multiplier) + elif text == "]" and len(square_brackets) > 0: + multiply_range(square_brackets.pop(), square_bracket_multiplier) + else: + res.append([text, 1.0]) + + for pos in round_brackets: + multiply_range(pos, round_bracket_multiplier) + + for pos in square_brackets: + multiply_range(pos, square_bracket_multiplier) + + if len(res) == 0: + res = [["", 1.0]] + + # merge runs of identical weights + i = 0 + while i + 1 < len(res): + if res[i][1] == res[i + 1][1]: + res[i][0] += res[i + 1][0] + res.pop(i + 1) + else: + i += 1 + + return res + + +def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int): + r""" + Tokenize a list of prompts and return its tokens with weights of each token. + + No padding, starting or ending token is included. + """ + tokens = [] + weights = [] + truncated = False + for text in prompt: + texts_and_weights = parse_prompt_attention(text) + text_token = [] + text_weight = [] + for word, weight in texts_and_weights: + # tokenize and discard the starting and the ending token + token = pipe.tokenizer(word).input_ids[1:-1] + text_token += token + # copy the weight by length of token + text_weight += [weight] * len(token) + # stop if the text is too long (longer than truncation limit) + if len(text_token) > max_length: + truncated = True + break + # truncate + if len(text_token) > max_length: + truncated = True + text_token = text_token[:max_length] + text_weight = text_weight[:max_length] + tokens.append(text_token) + weights.append(text_weight) + if truncated: + logger.warning("Prompt was truncated. 
Try to shorten the prompt or increase max_embeddings_multiples") + return tokens, weights + + +def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77): + r""" + Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. + """ + max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) + weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length + for i in range(len(tokens)): + tokens[i] = [bos] + tokens[i] + [eos] + [pad] * (max_length - 2 - len(tokens[i])) + if no_boseos_middle: + weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) + else: + w = [] + if len(weights[i]) == 0: + w = [1.0] * weights_length + else: + for j in range(max_embeddings_multiples): + w.append(1.0) # weight for starting token in this chunk + w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))] + w.append(1.0) # weight for ending token in this chunk + w += [1.0] * (weights_length - len(w)) + weights[i] = w[:] + + return tokens, weights + + +def get_hidden_states(text_encoder, input_ids, is_sdxl_text_encoder2: bool, eos_token_id, device): + if not is_sdxl_text_encoder2: + # text_encoder1: same as SD1/2 + enc_out = text_encoder(input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=True) + hidden_states = enc_out["hidden_states"][11] + pool = None + else: + # text_encoder2 + enc_out = text_encoder(input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=True) + hidden_states = enc_out["hidden_states"][-2] # penuultimate layer + # pool = enc_out["text_embeds"] + pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], input_ids, eos_token_id) + hidden_states = hidden_states.to(device) + if pool is not None: + pool = pool.to(device) + return hidden_states, pool + + +def get_unweighted_text_embeddings( + pipe: StableDiffusionPipeline, + text_input: torch.Tensor, + chunk_length: int, + clip_skip: int, + eos: int, + pad: int, + is_sdxl_text_encoder2: bool, + no_boseos_middle: Optional[bool] = True, +): + """ + When the length of tokens is a multiple of the capacity of the text encoder, + it should be split into chunks and sent to the text encoder individually. 
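To make the chunked encoding concrete, here is a small worked example of the length bookkeeping used by `get_unweighted_text_embeddings` and `get_weighted_text_embeddings`; the 77-token context length is the usual CLIP value and is assumed here:

```python
# Sketch of the chunking arithmetic (75 content tokens per chunk plus BOS/EOS).
model_max_length = 77          # pipe.tokenizer.model_max_length
max_embeddings_multiples = 3

# maximum padded length handled by get_weighted_text_embeddings
max_length = (model_max_length - 2) * max_embeddings_multiples + 2
print(max_length)  # 227

# a 150-token prompt needs two 75-token chunks
n_tokens = 150
multiples_needed = (n_tokens - 1) // (model_max_length - 2) + 1
print(multiples_needed)  # 2

# the padded input of length 152 (2 + 2*75) is then split back into chunks of 77
text_input_len = (model_max_length - 2) * multiples_needed + 2
print((text_input_len - 2) // (model_max_length - 2))  # 2 chunks
```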
+ """ + max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) + text_pool = None + if max_embeddings_multiples > 1: + text_embeddings = [] + for i in range(max_embeddings_multiples): + # extract the i-th chunk + text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone() + + # cover the head and the tail by the starting and the ending tokens + text_input_chunk[:, 0] = text_input[0, 0] + if pad == eos: # v1 + text_input_chunk[:, -1] = text_input[0, -1] + else: # v2 + for j in range(len(text_input_chunk)): + if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある + text_input_chunk[j, -1] = eos + if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD + text_input_chunk[j, 1] = eos + + text_embedding, current_text_pool = get_hidden_states( + pipe.text_encoder, text_input_chunk, is_sdxl_text_encoder2, eos, pipe.device + ) + if text_pool is None: + text_pool = current_text_pool + + if no_boseos_middle: + if i == 0: + # discard the ending token + text_embedding = text_embedding[:, :-1] + elif i == max_embeddings_multiples - 1: + # discard the starting token + text_embedding = text_embedding[:, 1:] + else: + # discard both starting and ending tokens + text_embedding = text_embedding[:, 1:-1] + + text_embeddings.append(text_embedding) + text_embeddings = torch.concat(text_embeddings, axis=1) + else: + text_embeddings, text_pool = get_hidden_states(pipe.text_encoder, text_input, is_sdxl_text_encoder2, eos, pipe.device) + return text_embeddings, text_pool + + +def get_weighted_text_embeddings( + pipe, # : SdxlStableDiffusionLongPromptWeightingPipeline, + prompt: Union[str, List[str]], + uncond_prompt: Optional[Union[str, List[str]]] = None, + max_embeddings_multiples: Optional[int] = 3, + no_boseos_middle: Optional[bool] = False, + skip_parsing: Optional[bool] = False, + skip_weighting: Optional[bool] = False, + clip_skip=None, + is_sdxl_text_encoder2=False, +): + r""" + Prompts can be assigned with local weights using brackets. For example, + prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful', + and the embedding tokens corresponding to the words get multiplied by a constant, 1.1. + + Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean. + + Args: + pipe (`StableDiffusionPipeline`): + Pipe to provide access to the tokenizer and the text encoder. + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + uncond_prompt (`str` or `List[str]`): + The unconditional prompt or prompts for guide the image generation. If unconditional prompt + is provided, the embeddings of prompt and uncond_prompt are concatenated. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + no_boseos_middle (`bool`, *optional*, defaults to `False`): + If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and + ending token in each of the chunk in the middle. + skip_parsing (`bool`, *optional*, defaults to `False`): + Skip the parsing of brackets. + skip_weighting (`bool`, *optional*, defaults to `False`): + Skip the weighting. When the parsing is skipped, it is forced True. 
+ """ + max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 + if isinstance(prompt, str): + prompt = [prompt] + + if not skip_parsing: + prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2) + if uncond_prompt is not None: + if isinstance(uncond_prompt, str): + uncond_prompt = [uncond_prompt] + uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2) + else: + prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids] + prompt_weights = [[1.0] * len(token) for token in prompt_tokens] + if uncond_prompt is not None: + if isinstance(uncond_prompt, str): + uncond_prompt = [uncond_prompt] + uncond_tokens = [ + token[1:-1] for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids + ] + uncond_weights = [[1.0] * len(token) for token in uncond_tokens] + + # round up the longest length of tokens to a multiple of (model_max_length - 2) + max_length = max([len(token) for token in prompt_tokens]) + if uncond_prompt is not None: + max_length = max(max_length, max([len(token) for token in uncond_tokens])) + + max_embeddings_multiples = min( + max_embeddings_multiples, + (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1, + ) + max_embeddings_multiples = max(1, max_embeddings_multiples) + max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 + + # pad the length of tokens and weights + bos = pipe.tokenizer.bos_token_id + eos = pipe.tokenizer.eos_token_id + pad = pipe.tokenizer.pad_token_id + prompt_tokens, prompt_weights = pad_tokens_and_weights( + prompt_tokens, + prompt_weights, + max_length, + bos, + eos, + pad, + no_boseos_middle=no_boseos_middle, + chunk_length=pipe.tokenizer.model_max_length, + ) + prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device) + if uncond_prompt is not None: + uncond_tokens, uncond_weights = pad_tokens_and_weights( + uncond_tokens, + uncond_weights, + max_length, + bos, + eos, + pad, + no_boseos_middle=no_boseos_middle, + chunk_length=pipe.tokenizer.model_max_length, + ) + uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device) + + # get the embeddings + text_embeddings, text_pool = get_unweighted_text_embeddings( + pipe, + prompt_tokens, + pipe.tokenizer.model_max_length, + clip_skip, + eos, + pad, + is_sdxl_text_encoder2, + no_boseos_middle=no_boseos_middle, + ) + prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device) + + if uncond_prompt is not None: + uncond_embeddings, uncond_pool = get_unweighted_text_embeddings( + pipe, + uncond_tokens, + pipe.tokenizer.model_max_length, + clip_skip, + eos, + pad, + is_sdxl_text_encoder2, + no_boseos_middle=no_boseos_middle, + ) + uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device) + + # assign weights to the prompts and normalize in the sense of mean + # TODO: should we normalize by chunk or in a whole (current implementation)? 
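The comment above describes the mean-preserving weighting that follows. A toy sketch of the same computation, with made-up embedding sizes, showing that rescaling by `previous_mean / current_mean` restores the original mean:

```python
import torch

# toy shapes: 1 prompt, 77 tokens, 8-dim embeddings (real dims are 768 / 1280)
emb = torch.randn(1, 77, 8) + 1.0
weights = torch.ones(1, 77)
weights[0, 5:10] = 1.1                      # tokens wrapped in "(...)" get a 1.1 multiplier

previous_mean = emb.float().mean(dim=[-2, -1])
weighted = emb * weights.unsqueeze(-1)
current_mean = weighted.float().mean(dim=[-2, -1])
weighted = weighted * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)

# the rescaling restores the original mean, so weighting shifts emphasis
# without changing the overall scale of the embedding
print(torch.allclose(weighted.mean(dim=[-2, -1]), previous_mean, atol=1e-4))  # True
```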
+ if (not skip_parsing) and (not skip_weighting): + previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= prompt_weights.unsqueeze(-1) + current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + if uncond_prompt is not None: + previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= uncond_weights.unsqueeze(-1) + current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + + if uncond_prompt is not None: + return text_embeddings, text_pool, uncond_embeddings, uncond_pool + return text_embeddings, text_pool, None, None + + +def preprocess_image(image): + w, h = image.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 + + +def preprocess_mask(mask, scale_factor=8): + mask = mask.convert("L") + w, h = mask.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"]) + mask = np.array(mask).astype(np.float32) / 255.0 + mask = np.tile(mask, (4, 1, 1)) + mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? + mask = 1 - mask # repaint white, keep black + mask = torch.from_numpy(mask) + return mask + + +def prepare_controlnet_image( + image: PIL.Image.Image, + width: int, + height: int, + batch_size: int, + num_images_per_prompt: int, + device: torch.device, + dtype: torch.dtype, + do_classifier_free_guidance: bool = False, + guess_mode: bool = False, +): + if not isinstance(image, torch.Tensor): + if isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + images = [] + + for image_ in image: + image_ = image_.convert("RGB") + image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]) + image_ = np.array(image_) + image_ = image_[None, :] + images.append(image_) + + image = images + + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + +class SdxlStableDiffusionLongPromptWeightingPipeline: + r""" + Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing + weighting in prompt. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 
+ + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + # if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"): + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: List[CLIPTextModel], + tokenizer: List[CLIPTokenizer], + unet: UNet2DConditionModel, + scheduler: SchedulerMixin, + # clip_skip: int, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, + clip_skip: int = 1, + ): + # clip skip is ignored currently + self.tokenizer = tokenizer[0] + self.text_encoder = text_encoder[0] + self.unet = unet + self.scheduler = scheduler + self.safety_checker = safety_checker + self.feature_extractor = feature_extractor + self.requires_safety_checker = requires_safety_checker + self.vae = vae + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.progress_bar = lambda x: tqdm(x, leave=False) + + self.clip_skip = clip_skip + self.tokenizers = tokenizer + self.text_encoders = text_encoder + + # self.__init__additional__() + + # def __init__additional__(self): + # if not hasattr(self, "vae_scale_factor"): + # setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1)) + + def to(self, device=None, dtype=None): + if device is not None: + self.device = device + # self.vae.to(device=self.device) + if dtype is not None: + self.dtype = dtype + + # do not move Text Encoders to device, because Text Encoder should be on CPU + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. 
+ """ + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + max_embeddings_multiples, + is_sdxl_text_encoder2, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + """ + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + if negative_prompt is None: + negative_prompt = [""] * batch_size + elif isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * batch_size + if batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + text_embeddings, text_pool, uncond_embeddings, uncond_pool = get_weighted_text_embeddings( + pipe=self, + prompt=prompt, + uncond_prompt=negative_prompt if do_classifier_free_guidance else None, + max_embeddings_multiples=max_embeddings_multiples, + clip_skip=self.clip_skip, + is_sdxl_text_encoder2=is_sdxl_text_encoder2, + ) + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) # ?? 
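The `repeat` just above, together with the `view` on the next line, duplicates each prompt's embeddings `num_images_per_prompt` times while keeping copies of the same prompt adjacent in the batch, which is what the `# ??` comment is asking about. A toy illustration with made-up sizes:

```python
import torch

bs_embed, seq_len, dim = 2, 77, 8       # toy sizes; real dims are 77 x 768/1280
num_images_per_prompt = 3

# prompt 0 -> all zeros, prompt 1 -> all ones, so the copies are easy to see
text_embeddings = torch.arange(bs_embed).float().view(bs_embed, 1, 1).expand(bs_embed, seq_len, dim)

out = text_embeddings.repeat(1, num_images_per_prompt, 1)        # (2, 231, 8)
out = out.view(bs_embed * num_images_per_prompt, seq_len, dim)   # (6, 77, 8)

print(out.shape)     # torch.Size([6, 77, 8])
print(out[:, 0, 0])  # tensor([0., 0., 0., 1., 1., 1.]) -- copies grouped per prompt
```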
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + if text_pool is not None: + text_pool = text_pool.repeat(1, num_images_per_prompt) + text_pool = text_pool.view(bs_embed * num_images_per_prompt, -1) + + if do_classifier_free_guidance: + bs_embed, seq_len, _ = uncond_embeddings.shape + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + if uncond_pool is not None: + uncond_pool = uncond_pool.repeat(1, num_images_per_prompt) + uncond_pool = uncond_pool.view(bs_embed * num_images_per_prompt, -1) + + return text_embeddings, text_pool, uncond_embeddings, uncond_pool + + return text_embeddings, text_pool, None, None + + def check_inputs(self, prompt, height, width, strength, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." + ) + + def get_timesteps(self, num_inference_steps, strength, device, is_text2img): + if is_text2img: + return self.scheduler.timesteps.to(device), num_inference_steps + else: + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + t_start = max(num_inference_steps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:].to(device) + return timesteps, num_inference_steps - t_start + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values.to(dtype)) + else: + has_nsfw_concept = None + return image, has_nsfw_concept + + def decode_latents(self, latents): + with torch.no_grad(): + latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents + + # print("post_quant_conv dtype:", self.vae.post_quant_conv.weight.dtype) # torch.float32 + # x = torch.nn.functional.conv2d(latents, self.vae.post_quant_conv.weight.detach(), stride=1, padding=0) + # print("latents dtype:", latents.dtype, "x dtype:", x.dtype) # torch.float32, torch.float16 + # self.vae.to("cpu") + # self.vae.set_use_memory_efficient_attention_xformers(False) + # image = self.vae.decode(latents.to("cpu")).sample + + image = self.vae.decode(latents.to(self.vae.dtype)).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will 
be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def prepare_latents(self, image, timestep, batch_size, height, width, dtype, device, generator, latents=None): + if image is None: + shape = ( + batch_size, + self.unet.in_channels, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + + if latents is None: + if device.type == "mps": + # randn does not work reproducibly on mps + latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + else: + latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents, None, None + else: + init_latent_dist = self.vae.encode(image).latent_dist + init_latents = init_latent_dist.sample(generator=generator) + init_latents = sdxl_model_util.VAE_SCALE_FACTOR * init_latents + init_latents = torch.cat([init_latents] * batch_size, dim=0) + init_latents_orig = init_latents + shape = init_latents.shape + + # add noise to latents using the timesteps + if device.type == "mps": + noise = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + else: + noise = torch.randn(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.add_noise(init_latents, noise, timestep) + return latents, init_latents_orig, noise + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + image: Union[torch.FloatTensor, PIL.Image.Image] = None, + mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None, + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + strength: float = 0.8, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + return_dict: bool = True, + controlnet=None, + controlnet_image=None, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + is_cancelled_callback: Optional[Callable[[], bool]] = None, + callback_steps: int = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. 
+ mask_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a + PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should + contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. + `image` will be used as a starting point, adding more noise to it the larger the `strength`. The + number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added + noise will be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + controlnet (`diffusers.ControlNetModel`, *optional*): + A controlnet model to be used for the inference. If not provided, controlnet will be disabled. + controlnet_image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*): + `Image`, or tensor representing an image batch, to be used as the starting point for the controlnet + inference. 
+ callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + is_cancelled_callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. If the function returns + `True`, the inference will be cancelled. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + `None` if cancelled by `is_cancelled_callback`, + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + if controlnet is not None and controlnet_image is None: + raise ValueError("controlnet_image must be provided if controlnet is not None.") + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, strength, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + # 実装を簡単にするためにtokenzer/text encoderを切り替えて二回呼び出す + # To simplify the implementation, switch the tokenzer/text encoder and call it twice + text_embeddings_list = [] + text_pool = None + uncond_embeddings_list = [] + uncond_pool = None + for i in range(len(self.tokenizers)): + self.tokenizer = self.tokenizers[i] + self.text_encoder = self.text_encoders[i] + + text_embeddings, tp1, uncond_embeddings, up1 = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + max_embeddings_multiples, + is_sdxl_text_encoder2=i == 1, + ) + text_embeddings_list.append(text_embeddings) + uncond_embeddings_list.append(uncond_embeddings) + + if tp1 is not None: + text_pool = tp1 + if up1 is not None: + uncond_pool = up1 + + dtype = self.unet.dtype + + # 4. 
Preprocess image and mask + if isinstance(image, PIL.Image.Image): + image = preprocess_image(image) + if image is not None: + image = image.to(device=self.device, dtype=dtype) + if isinstance(mask_image, PIL.Image.Image): + mask_image = preprocess_mask(mask_image, self.vae_scale_factor) + if mask_image is not None: + mask = mask_image.to(device=self.device, dtype=dtype) + mask = torch.cat([mask] * batch_size * num_images_per_prompt) + else: + mask = None + + # ControlNet is not working yet in SDXL, but keep the code here for future use + if controlnet_image is not None: + controlnet_image = prepare_controlnet_image( + controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False + ) + + # 5. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 6. Prepare latent variables + latents, init_latents_orig, noise = self.prepare_latents( + image, + latent_timestep, + batch_size * num_images_per_prompt, + height, + width, + dtype, + device, + generator, + latents, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # create size embs and concat embeddings for SDXL + orig_size = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1).to(dtype) + crop_size = torch.zeros_like(orig_size) + target_size = orig_size + embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, device).to(dtype) + + # make conditionings + if do_classifier_free_guidance: + text_embeddings = torch.cat(text_embeddings_list, dim=2) + uncond_embeddings = torch.cat(uncond_embeddings_list, dim=2) + text_embedding = torch.cat([uncond_embeddings, text_embeddings]).to(dtype) + + cond_vector = torch.cat([text_pool, embs], dim=1) + uncond_vector = torch.cat([uncond_pool, embs], dim=1) + vector_embedding = torch.cat([uncond_vector, cond_vector]).to(dtype) + else: + text_embedding = torch.cat(text_embeddings_list, dim=2).to(dtype) + vector_embedding = torch.cat([text_pool, embs], dim=1).to(dtype) + + # 8. 
Denoising loop + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + unet_additional_args = {} + if controlnet is not None: + down_block_res_samples, mid_block_res_sample = controlnet( + latent_model_input, + t, + encoder_hidden_states=text_embeddings, + controlnet_cond=controlnet_image, + conditioning_scale=1.0, + guess_mode=False, + return_dict=False, + ) + unet_additional_args["down_block_additional_residuals"] = down_block_res_samples + unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding) + noise_pred = noise_pred.to(dtype) # U-Net changes dtype in LoRA training + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if mask is not None: + # masking + init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t])) + latents = (init_latents_proper * mask) + (latents * (1 - mask)) + + # call the callback, if provided + if i % callback_steps == 0: + if callback is not None: + callback(i, t, latents) + if is_cancelled_callback is not None and is_cancelled_callback(): + return None + + return latents + + def latents_to_image(self, latents): + # 9. Post-processing + image = self.decode_latents(latents.to(self.vae.dtype)) + image = self.numpy_to_pil(image) + return image + + # copy from pil_utils.py + def numpy_to_pil(self, images: np.ndarray) -> Image.Image: + """ + Convert a numpy image or a batch of images to a PIL image. + """ + if images.ndim == 3: + images = images[None, ...] + images = (images * 255).round().astype("uint8") + if images.shape[-1] == 1: + # special case for grayscale (single channel) images + pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] + else: + pil_images = [Image.fromarray(image) for image in images] + + return pil_images + + def text2img( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + is_cancelled_callback: Optional[Callable[[], bool]] = None, + callback_steps: int = 1, + ): + r""" + Function for text-to-image generation. + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. 
+            width (`int`, *optional*, defaults to 512):
+                The width in pixels of the generated image.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            guidance_scale (`float`, *optional*, defaults to 7.5):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`torch.Generator`, *optional*):
+                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+                deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+                The max multiple length of prompt embeddings compared to the max output length of the text encoder.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                plain tuple.
+            callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. The function will be
+                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+            is_cancelled_callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. If the function returns
+                `True`, the inference will be cancelled.
+            callback_steps (`int`, *optional*, defaults to 1):
+                The frequency at which the `callback` function will be called. If not specified, the callback will be
+                called at every step.
+        Returns:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+            When returning a tuple, the first element is a list with the generated images, and the second element is a
+            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+            (nsfw) content, according to the `safety_checker`.
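+
+        Example:
+            A minimal usage sketch (for illustration it assumes `pipe` is an already-constructed instance
+            of this pipeline on a GPU). Note that in this implementation the call returns the denoised
+            latents; decode them to PIL images with `latents_to_image`:
+
+            >>> latents = pipe.text2img("an astronaut riding a horse", width=1024, height=1024)
+            >>> images = pipe.latents_to_image(latents)
+            >>> images[0].save("astronaut.png")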
+ """ + return self.__call__( + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + latents=latents, + max_embeddings_multiples=max_embeddings_multiples, + output_type=output_type, + return_dict=return_dict, + callback=callback, + is_cancelled_callback=is_cancelled_callback, + callback_steps=callback_steps, + ) + + def img2img( + self, + image: Union[torch.FloatTensor, PIL.Image.Image], + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + is_cancelled_callback: Optional[Callable[[], bool]] = None, + callback_steps: int = 1, + ): + r""" + Function for image-to-image generation. + Args: + image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. + `image` will be used as a starting point, adding more noise to it the larger the `strength`. The + number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added + noise will be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter will be modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. 
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + is_cancelled_callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. If the function returns + `True`, the inference will be cancelled. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + return self.__call__( + prompt=prompt, + negative_prompt=negative_prompt, + image=image, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + strength=strength, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + max_embeddings_multiples=max_embeddings_multiples, + output_type=output_type, + return_dict=return_dict, + callback=callback, + is_cancelled_callback=is_cancelled_callback, + callback_steps=callback_steps, + ) + + def inpaint( + self, + image: Union[torch.FloatTensor, PIL.Image.Image], + mask_image: Union[torch.FloatTensor, PIL.Image.Image], + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + is_cancelled_callback: Optional[Callable[[], bool]] = None, + callback_steps: int = 1, + ): + r""" + Function for inpaint. + Args: + image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. This is the image whose masked region will be inpainted. + mask_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a + PIL image, it will be converted to a single channel (luminance) before use. 
If it's a tensor, it should
+                contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
+            prompt (`str` or `List[str]`):
+                The prompt or prompts to guide the image generation.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+                if `guidance_scale` is less than `1`).
+            strength (`float`, *optional*, defaults to 0.8):
+                Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
+                is 1, the denoising process will be run on the masked area for the full number of iterations specified
+                in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more
+                noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
+                the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
+            guidance_scale (`float`, *optional*, defaults to 7.5):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`torch.Generator`, *optional*):
+                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+                deterministic.
+            max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+                The max multiple length of prompt embeddings compared to the max output length of the text encoder.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                plain tuple.
+            callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. The function will be
+                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+            is_cancelled_callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. If the function returns
+                `True`, the inference will be cancelled.
+            callback_steps (`int`, *optional*, defaults to 1):
+                The frequency at which the `callback` function will be called. If not specified, the callback will be
+                called at every step.
+        Returns:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + return self.__call__( + prompt=prompt, + negative_prompt=negative_prompt, + image=image, + mask_image=mask_image, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + strength=strength, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + max_embeddings_multiples=max_embeddings_multiples, + output_type=output_type, + return_dict=return_dict, + callback=callback, + is_cancelled_callback=is_cancelled_callback, + callback_steps=callback_steps, + ) diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py new file mode 100644 index 000000000..2f0154cae --- /dev/null +++ b/library/sdxl_model_util.py @@ -0,0 +1,572 @@ +import torch +from accelerate import init_empty_weights +from accelerate.utils.modeling import set_module_tensor_to_device +from safetensors.torch import load_file, save_file +from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer +from typing import List +from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel +from library import model_util +from library import sdxl_original_unet + + +VAE_SCALE_FACTOR = 0.13025 +MODEL_VERSION_SDXL_BASE_V1_0 = "sdxl_base_v1-0" + +# Diffusersの設定を読み込むための参照モデル +DIFFUSERS_REF_MODEL_ID_SDXL = "stabilityai/stable-diffusion-xl-base-1.0" + +DIFFUSERS_SDXL_UNET_CONFIG = { + "act_fn": "silu", + "addition_embed_type": "text_time", + "addition_embed_type_num_heads": 64, + "addition_time_embed_dim": 256, + "attention_head_dim": [5, 10, 20], + "block_out_channels": [320, 640, 1280], + "center_input_sample": False, + "class_embed_type": None, + "class_embeddings_concat": False, + "conv_in_kernel": 3, + "conv_out_kernel": 3, + "cross_attention_dim": 2048, + "cross_attention_norm": None, + "down_block_types": ["DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"], + "downsample_padding": 1, + "dual_cross_attention": False, + "encoder_hid_dim": None, + "encoder_hid_dim_type": None, + "flip_sin_to_cos": True, + "freq_shift": 0, + "in_channels": 4, + "layers_per_block": 2, + "mid_block_only_cross_attention": None, + "mid_block_scale_factor": 1, + "mid_block_type": "UNetMidBlock2DCrossAttn", + "norm_eps": 1e-05, + "norm_num_groups": 32, + "num_attention_heads": None, + "num_class_embeds": None, + "only_cross_attention": False, + "out_channels": 4, + "projection_class_embeddings_input_dim": 2816, + "resnet_out_scale_factor": 1.0, + "resnet_skip_time_act": False, + "resnet_time_scale_shift": "default", + "sample_size": 128, + "time_cond_proj_dim": None, + "time_embedding_act_fn": None, + "time_embedding_dim": None, + "time_embedding_type": "positional", + "timestep_post_act": None, + "transformer_layers_per_block": [1, 2, 10], + "up_block_types": ["CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"], + "upcast_attention": False, + "use_linear_projection": True, +} + + +def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length): + SDXL_KEY_PREFIX = "conditioner.embedders.1.model." 
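+
+    # For orientation, the conversion below maps open_clip style keys under this prefix to
+    # HuggingFace CLIP keys, roughly like this (illustrative examples, not an exhaustive list):
+    #   conditioner.embedders.1.model.transformer.resblocks.0.ln_1.weight
+    #     -> text_model.encoder.layers.0.layer_norm1.weight
+    #   conditioner.embedders.1.model.token_embedding.weight
+    #     -> text_model.embeddings.token_embedding.weight
+    #   conditioner.embedders.1.model.text_projection
+    #     -> text_projection.weight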
+ + # SD2のと、基本的には同じ。logit_scaleを後で使うので、それを追加で返す + # logit_scaleはcheckpointの保存時に使用する + def convert_key(key): + # common conversion + key = key.replace(SDXL_KEY_PREFIX + "transformer.", "text_model.encoder.") + key = key.replace(SDXL_KEY_PREFIX, "text_model.") + + if "resblocks" in key: + # resblocks conversion + key = key.replace(".resblocks.", ".layers.") + if ".ln_" in key: + key = key.replace(".ln_", ".layer_norm") + elif ".mlp." in key: + key = key.replace(".c_fc.", ".fc1.") + key = key.replace(".c_proj.", ".fc2.") + elif ".attn.out_proj" in key: + key = key.replace(".attn.out_proj.", ".self_attn.out_proj.") + elif ".attn.in_proj" in key: + key = None # 特殊なので後で処理する + else: + raise ValueError(f"unexpected key in SD: {key}") + elif ".positional_embedding" in key: + key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight") + elif ".text_projection" in key: + key = key.replace("text_model.text_projection", "text_projection.weight") + elif ".logit_scale" in key: + key = None # 後で処理する + elif ".token_embedding" in key: + key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight") + elif ".ln_final" in key: + key = key.replace(".ln_final", ".final_layer_norm") + # ckpt from comfy has this key: text_model.encoder.text_model.embeddings.position_ids + elif ".embeddings.position_ids" in key: + key = None # remove this key: make position_ids by ourselves + return key + + keys = list(checkpoint.keys()) + new_sd = {} + for key in keys: + new_key = convert_key(key) + if new_key is None: + continue + new_sd[new_key] = checkpoint[key] + + # attnの変換 + for key in keys: + if ".resblocks" in key and ".attn.in_proj_" in key: + # 三つに分割 + values = torch.chunk(checkpoint[key], 3) + + key_suffix = ".weight" if "weight" in key else ".bias" + key_pfx = key.replace(SDXL_KEY_PREFIX + "transformer.resblocks.", "text_model.encoder.layers.") + key_pfx = key_pfx.replace("_weight", "") + key_pfx = key_pfx.replace("_bias", "") + key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.") + new_sd[key_pfx + "q_proj" + key_suffix] = values[0] + new_sd[key_pfx + "k_proj" + key_suffix] = values[1] + new_sd[key_pfx + "v_proj" + key_suffix] = values[2] + + # original SD にはないので、position_idsを追加 + position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64) + new_sd["text_model.embeddings.position_ids"] = position_ids + + # logit_scale はDiffusersには含まれないが、保存時に戻したいので別途返す + logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None) + + return new_sd, logit_scale + + +# load state_dict without allocating new tensors +def _load_state_dict_on_device(model, state_dict, device, dtype=None): + # dtype will use fp32 as default + missing_keys = list(model.state_dict().keys() - state_dict.keys()) + unexpected_keys = list(state_dict.keys() - model.state_dict().keys()) + + # similar to model.load_state_dict() + if not missing_keys and not unexpected_keys: + for k in list(state_dict.keys()): + set_module_tensor_to_device(model, k, device, value=state_dict.pop(k), dtype=dtype) + return "" + + # error_msgs + error_msgs: List[str] = [] + if missing_keys: + error_msgs.insert(0, "Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in missing_keys))) + if unexpected_keys: + error_msgs.insert(0, "Unexpected key(s) in state_dict: {}. 
".format(", ".join('"{}"'.format(k) for k in unexpected_keys))) + + raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))) + + +def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None): + # model_version is reserved for future use + # dtype is used for full_fp16/bf16 integration. Text Encoder will remain fp32, because it runs on CPU when caching + + # Load the state dict + if model_util.is_safetensors(ckpt_path): + checkpoint = None + try: + state_dict = load_file(ckpt_path, device=map_location) + except: + state_dict = load_file(ckpt_path) # prevent device invalid Error + epoch = None + global_step = None + else: + checkpoint = torch.load(ckpt_path, map_location=map_location) + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + epoch = checkpoint.get("epoch", 0) + global_step = checkpoint.get("global_step", 0) + else: + state_dict = checkpoint + epoch = 0 + global_step = 0 + checkpoint = None + + # U-Net + print("building U-Net") + with init_empty_weights(): + unet = sdxl_original_unet.SdxlUNet2DConditionModel() + + print("loading U-Net from checkpoint") + unet_sd = {} + for k in list(state_dict.keys()): + if k.startswith("model.diffusion_model."): + unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k) + info = _load_state_dict_on_device(unet, unet_sd, device=map_location, dtype=dtype) + print("U-Net: ", info) + + # Text Encoders + print("building text encoders") + + # Text Encoder 1 is same to Stability AI's SDXL + text_model1_cfg = CLIPTextConfig( + vocab_size=49408, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + max_position_embeddings=77, + hidden_act="quick_gelu", + layer_norm_eps=1e-05, + dropout=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + model_type="clip_text_model", + projection_dim=768, + # torch_dtype="float32", + # transformers_version="4.25.0.dev0", + ) + with init_empty_weights(): + text_model1 = CLIPTextModel._from_config(text_model1_cfg) + + # Text Encoder 2 is different from Stability AI's SDXL. SDXL uses open clip, but we use the model from HuggingFace. + # Note: Tokenizer from HuggingFace is different from SDXL. We must use open clip's tokenizer. 
+ text_model2_cfg = CLIPTextConfig( + vocab_size=49408, + hidden_size=1280, + intermediate_size=5120, + num_hidden_layers=32, + num_attention_heads=20, + max_position_embeddings=77, + hidden_act="gelu", + layer_norm_eps=1e-05, + dropout=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + model_type="clip_text_model", + projection_dim=1280, + # torch_dtype="float32", + # transformers_version="4.25.0.dev0", + ) + with init_empty_weights(): + text_model2 = CLIPTextModelWithProjection(text_model2_cfg) + + print("loading text encoders from checkpoint") + te1_sd = {} + te2_sd = {} + for k in list(state_dict.keys()): + if k.startswith("conditioner.embedders.0.transformer."): + te1_sd[k.replace("conditioner.embedders.0.transformer.", "")] = state_dict.pop(k) + elif k.startswith("conditioner.embedders.1.model."): + te2_sd[k] = state_dict.pop(k) + + # 一部のposition_idsがないモデルへの対応 / add position_ids for some models + if "text_model.embeddings.position_ids" not in te1_sd: + te1_sd["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0) + + info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location) # remain fp32 + print("text encoder 1:", info1) + + converted_sd, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77) + info2 = _load_state_dict_on_device(text_model2, converted_sd, device=map_location) # remain fp32 + print("text encoder 2:", info2) + + # prepare vae + print("building VAE") + vae_config = model_util.create_vae_diffusers_config() + with init_empty_weights(): + vae = AutoencoderKL(**vae_config) + + print("loading VAE from checkpoint") + converted_vae_checkpoint = model_util.convert_ldm_vae_checkpoint(state_dict, vae_config) + info = _load_state_dict_on_device(vae, converted_vae_checkpoint, device=map_location, dtype=dtype) + print("VAE:", info) + + ckpt_info = (epoch, global_step) if epoch is not None else None + return text_model1, text_model2, vae, unet, logit_scale, ckpt_info + + +def make_unet_conversion_map(): + unet_conversion_map_layer = [] + + for i in range(3): # num_blocks is 3 in sdxl + # loop over downblocks/upblocks + for j in range(2): + # loop over resnets/attentions for downblocks + hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." + sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." + unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) + + if i < 3: + # no attention layers in down_blocks.3 + hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." + sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." + unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) + + for j in range(3): + # loop over resnets/attentions for upblocks + hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." + sd_up_res_prefix = f"output_blocks.{3*i + j}.0." + unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) + + # if i > 0: commentout for sdxl + # no attention layers in up_blocks.0 + hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." + sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." + unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) + + if i < 3: + # no downsample in down_blocks.3 + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." + sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." + unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) + + # no upsample in up_blocks.3 + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." 
+ sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl + unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) + + hf_mid_atn_prefix = "mid_block.attentions.0." + sd_mid_atn_prefix = "middle_block.1." + unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) + + for j in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{j}." + sd_mid_res_prefix = f"middle_block.{2*j}." + unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) + + unet_conversion_map_resnet = [ + # (stable-diffusion, HF Diffusers) + ("in_layers.0.", "norm1."), + ("in_layers.2.", "conv1."), + ("out_layers.0.", "norm2."), + ("out_layers.3.", "conv2."), + ("emb_layers.1.", "time_emb_proj."), + ("skip_connection.", "conv_shortcut."), + ] + + unet_conversion_map = [] + for sd, hf in unet_conversion_map_layer: + if "resnets" in hf: + for sd_res, hf_res in unet_conversion_map_resnet: + unet_conversion_map.append((sd + sd_res, hf + hf_res)) + else: + unet_conversion_map.append((sd, hf)) + + for j in range(2): + hf_time_embed_prefix = f"time_embedding.linear_{j+1}." + sd_time_embed_prefix = f"time_embed.{j*2}." + unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix)) + + for j in range(2): + hf_label_embed_prefix = f"add_embedding.linear_{j+1}." + sd_label_embed_prefix = f"label_emb.0.{j*2}." + unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix)) + + unet_conversion_map.append(("input_blocks.0.0.", "conv_in.")) + unet_conversion_map.append(("out.0.", "conv_norm_out.")) + unet_conversion_map.append(("out.2.", "conv_out.")) + + return unet_conversion_map + + +def convert_diffusers_unet_state_dict_to_sdxl(du_sd): + unet_conversion_map = make_unet_conversion_map() + + conversion_map = {hf: sd for sd, hf in unet_conversion_map} + return convert_unet_state_dict(du_sd, conversion_map) + + +def convert_unet_state_dict(src_sd, conversion_map): + converted_sd = {} + for src_key, value in src_sd.items(): + # さすがに全部回すのは時間がかかるので右から要素を削りつつprefixを探す + src_key_fragments = src_key.split(".")[:-1] # remove weight/bias + while len(src_key_fragments) > 0: + src_key_prefix = ".".join(src_key_fragments) + "." + if src_key_prefix in conversion_map: + converted_prefix = conversion_map[src_key_prefix] + converted_key = converted_prefix + src_key[len(src_key_prefix) :] + converted_sd[converted_key] = value + break + src_key_fragments.pop(-1) + assert len(src_key_fragments) > 0, f"key {src_key} not found in conversion map" + + return converted_sd + + +def convert_sdxl_unet_state_dict_to_diffusers(sd): + unet_conversion_map = make_unet_conversion_map() + + conversion_dict = {sd: hf for sd, hf in unet_conversion_map} + return convert_unet_state_dict(sd, conversion_dict) + + +def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, logit_scale): + def convert_key(key): + # position_idsの除去 + if ".position_ids" in key: + return None + + # common + key = key.replace("text_model.encoder.", "transformer.") + key = key.replace("text_model.", "") + if "layers" in key: + # resblocks conversion + key = key.replace(".layers.", ".resblocks.") + if ".layer_norm" in key: + key = key.replace(".layer_norm", ".ln_") + elif ".mlp." in key: + key = key.replace(".fc1.", ".c_fc.") + key = key.replace(".fc2.", ".c_proj.") + elif ".self_attn.out_proj" in key: + key = key.replace(".self_attn.out_proj.", ".attn.out_proj.") + elif ".self_attn." 
in key: + key = None # 特殊なので後で処理する + else: + raise ValueError(f"unexpected key in DiffUsers model: {key}") + elif ".position_embedding" in key: + key = key.replace("embeddings.position_embedding.weight", "positional_embedding") + elif ".token_embedding" in key: + key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight") + elif "text_projection" in key: # no dot in key + key = key.replace("text_projection.weight", "text_projection") + elif "final_layer_norm" in key: + key = key.replace("final_layer_norm", "ln_final") + return key + + keys = list(checkpoint.keys()) + new_sd = {} + for key in keys: + new_key = convert_key(key) + if new_key is None: + continue + new_sd[new_key] = checkpoint[key] + + # attnの変換 + for key in keys: + if "layers" in key and "q_proj" in key: + # 三つを結合 + key_q = key + key_k = key.replace("q_proj", "k_proj") + key_v = key.replace("q_proj", "v_proj") + + value_q = checkpoint[key_q] + value_k = checkpoint[key_k] + value_v = checkpoint[key_v] + value = torch.cat([value_q, value_k, value_v]) + + new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.") + new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_") + new_sd[new_key] = value + + if logit_scale is not None: + new_sd["logit_scale"] = logit_scale + + return new_sd + + +def save_stable_diffusion_checkpoint( + output_file, + text_encoder1, + text_encoder2, + unet, + epochs, + steps, + ckpt_info, + vae, + logit_scale, + metadata, + save_dtype=None, +): + state_dict = {} + + def update_sd(prefix, sd): + for k, v in sd.items(): + key = prefix + k + if save_dtype is not None: + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + + # Convert the UNet model + update_sd("model.diffusion_model.", unet.state_dict()) + + # Convert the text encoders + update_sd("conditioner.embedders.0.transformer.", text_encoder1.state_dict()) + + text_enc2_dict = convert_text_encoder_2_state_dict_to_sdxl(text_encoder2.state_dict(), logit_scale) + update_sd("conditioner.embedders.1.model.", text_enc2_dict) + + # Convert the VAE + vae_dict = model_util.convert_vae_state_dict(vae.state_dict()) + update_sd("first_stage_model.", vae_dict) + + # Put together new checkpoint + key_count = len(state_dict.keys()) + new_ckpt = {"state_dict": state_dict} + + # epoch and global_step are sometimes not int + if ckpt_info is not None: + epochs += ckpt_info[0] + steps += ckpt_info[1] + + new_ckpt["epoch"] = epochs + new_ckpt["global_step"] = steps + + if model_util.is_safetensors(output_file): + save_file(state_dict, output_file, metadata) + else: + torch.save(new_ckpt, output_file) + + return key_count + + +def save_diffusers_checkpoint( + output_dir, text_encoder1, text_encoder2, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False, save_dtype=None +): + from diffusers import StableDiffusionXLPipeline + + # convert U-Net + unet_sd = unet.state_dict() + du_unet_sd = convert_sdxl_unet_state_dict_to_diffusers(unet_sd) + + diffusers_unet = UNet2DConditionModel(**DIFFUSERS_SDXL_UNET_CONFIG) + if save_dtype is not None: + diffusers_unet.to(save_dtype) + diffusers_unet.load_state_dict(du_unet_sd) + + # create pipeline to save + if pretrained_model_name_or_path is None: + pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_SDXL + + scheduler = EulerDiscreteScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler") + tokenizer1 = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer") + tokenizer2 = 
CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer_2") + if vae is None: + vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae") + + # prevent local path from being saved + def remove_name_or_path(model): + if hasattr(model, "config"): + model.config._name_or_path = None + model.config._name_or_path = None + + remove_name_or_path(diffusers_unet) + remove_name_or_path(text_encoder1) + remove_name_or_path(text_encoder2) + remove_name_or_path(scheduler) + remove_name_or_path(tokenizer1) + remove_name_or_path(tokenizer2) + remove_name_or_path(vae) + + pipeline = StableDiffusionXLPipeline( + unet=diffusers_unet, + text_encoder=text_encoder1, + text_encoder_2=text_encoder2, + vae=vae, + scheduler=scheduler, + tokenizer=tokenizer1, + tokenizer_2=tokenizer2, + ) + if save_dtype is not None: + pipeline.to(None, save_dtype) + pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors) diff --git a/library/sdxl_original_unet.py b/library/sdxl_original_unet.py new file mode 100644 index 000000000..586909bdb --- /dev/null +++ b/library/sdxl_original_unet.py @@ -0,0 +1,1148 @@ +# Diffusersのコードをベースとした sd_xl_baseのU-Net +# state dictの形式をSDXLに合わせてある + +""" + target: sgm.modules.diffusionmodules.openaimodel.UNetModel + params: + adm_in_channels: 2816 + num_classes: sequential + use_checkpoint: True + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [4, 2] + num_res_blocks: 2 + channel_mult: [1, 2, 4] + num_head_channels: 64 + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: [1, 2, 10] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16 + context_dim: 2048 + spatial_transformer_attn_type: softmax-xformers + legacy: False +""" + +import math +from types import SimpleNamespace +from typing import Optional +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import functional as F +from einops import rearrange + + +IN_CHANNELS: int = 4 +OUT_CHANNELS: int = 4 +ADM_IN_CHANNELS: int = 2816 +CONTEXT_DIM: int = 2048 +MODEL_CHANNELS: int = 320 +TIME_EMBED_DIM = 320 * 4 + +USE_REENTRANT = True + +# region memory effcient attention + +# FlashAttentionを使うCrossAttention +# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py +# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE + +# constants + +EPSILON = 1e-6 + +# helper functions + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +# flash attention forwards and backwards + +# https://arxiv.org/abs/2205.14135 + + +class FlashAttentionFunction(torch.autograd.Function): + @staticmethod + @torch.no_grad() + def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): + """Algorithm 2 in the paper""" + + device = q.device + dtype = q.dtype + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + + o = torch.zeros_like(q) + all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) + all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device) + + scale = q.shape[-1] ** -0.5 + + if not exists(mask): + mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) + else: + mask = rearrange(mask, "b n -> b 1 1 n") + mask = mask.split(q_bucket_size, dim=-1) + + row_splits = zip( + q.split(q_bucket_size, 
dim=-2), + o.split(q_bucket_size, dim=-2), + mask, + all_row_sums.split(q_bucket_size, dim=-2), + all_row_maxes.split(q_bucket_size, dim=-2), + ) + + for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff + + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + ) + + for k_ind, (kc, vc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size + + attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale + + if exists(row_mask): + attn_weights.masked_fill_(~row_mask, max_neg_value) + + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( + q_start_index - k_start_index + 1 + ) + attn_weights.masked_fill_(causal_mask, max_neg_value) + + block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) + attn_weights -= block_row_maxes + exp_weights = torch.exp(attn_weights) + + if exists(row_mask): + exp_weights.masked_fill_(~row_mask, 0.0) + + block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON) + + new_row_maxes = torch.maximum(block_row_maxes, row_maxes) + + exp_values = torch.einsum("... i j, ... j d -> ... i d", exp_weights, vc) + + exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) + exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) + + new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums + + oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values) + + row_maxes.copy_(new_row_maxes) + row_sums.copy_(new_row_sums) + + ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) + ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) + + return o + + @staticmethod + @torch.no_grad() + def backward(ctx, do): + """Algorithm 4 in the paper""" + + causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args + q, k, v, o, l, m = ctx.saved_tensors + + device = q.device + + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + + dq = torch.zeros_like(q) + dk = torch.zeros_like(k) + dv = torch.zeros_like(v) + + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + do.split(q_bucket_size, dim=-2), + mask, + l.split(q_bucket_size, dim=-2), + m.split(q_bucket_size, dim=-2), + dq.split(q_bucket_size, dim=-2), + ) + + for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff + + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + dk.split(k_bucket_size, dim=-2), + dv.split(k_bucket_size, dim=-2), + ) + + for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size + + attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale + + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( + q_start_index - k_start_index + 1 + ) + attn_weights.masked_fill_(causal_mask, max_neg_value) + + exp_attn_weights = torch.exp(attn_weights - mc) + + if exists(row_mask): + exp_attn_weights.masked_fill_(~row_mask, 0.0) + + p = exp_attn_weights / lc + + dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc) + dp = torch.einsum("... i d, ... j d -> ... 
i j", doc, vc) + + D = (doc * oc).sum(dim=-1, keepdims=True) + ds = p * scale * (dp - D) + + dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc) + dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc) + + dqc.add_(dq_chunk) + dkc.add_(dk_chunk) + dvc.add_(dv_chunk) + + return dq, dk, dv, None, None, None, None + + +# endregion + + +def get_parameter_dtype(parameter: torch.nn.Module): + return next(parameter.parameters()).dtype + + +def get_parameter_device(parameter: torch.nn.Module): + return next(parameter.parameters()).device + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the + embeddings. :return: an [N x dim] Tensor of positional embeddings. + """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings: flipped from Diffusers original ver because always flip_sin_to_cos=True + emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + if self.weight.dtype != torch.float32: + return super().forward(x) + return super().forward(x.float()).type(x.dtype) + + +class ResnetBlock2D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + self.in_layers = nn.Sequential( + GroupNorm32(32, in_channels), + nn.SiLU(), + nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1), + ) + + self.emb_layers = nn.Sequential(nn.SiLU(), nn.Linear(TIME_EMBED_DIM, out_channels)) + + self.out_layers = nn.Sequential( + GroupNorm32(32, out_channels), + nn.SiLU(), + nn.Identity(), # to make state_dict compatible with original model + nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), + ) + + if in_channels != out_channels: + self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + else: + self.skip_connection = nn.Identity() + + self.gradient_checkpointing = False + + def forward_body(self, x, emb): + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + h = h + emb_out[:, :, None, None] + h = self.out_layers(h) + x = self.skip_connection(x) + return x + h + + def forward(self, x, emb): + if self.training and self.gradient_checkpointing: + # print("ResnetBlock2D: gradient_checkpointing") + + def create_custom_forward(func): + def custom_forward(*inputs): + return func(*inputs) + + return custom_forward + + x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), x, emb, use_reentrant=USE_REENTRANT) + else: + x = self.forward_body(x, emb) + + return x + + +class 
Downsample2D(nn.Module): + def __init__(self, channels, out_channels): + super().__init__() + + self.channels = channels + self.out_channels = out_channels + + self.op = nn.Conv2d(self.channels, self.out_channels, 3, stride=2, padding=1) + + self.gradient_checkpointing = False + + def forward_body(self, hidden_states): + assert hidden_states.shape[1] == self.channels + hidden_states = self.op(hidden_states) + + return hidden_states + + def forward(self, hidden_states): + if self.training and self.gradient_checkpointing: + # print("Downsample2D: gradient_checkpointing") + + def create_custom_forward(func): + def custom_forward(*inputs): + return func(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.forward_body), hidden_states, use_reentrant=USE_REENTRANT + ) + else: + hidden_states = self.forward_body(hidden_states) + + return hidden_states + + +class CrossAttention(nn.Module): + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + upcast_attention: bool = False, + ): + super().__init__() + inner_dim = dim_head * heads + cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + + self.scale = dim_head**-0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(inner_dim, query_dim)) + # no dropout here + + self.use_memory_efficient_attention_xformers = False + self.use_memory_efficient_attention_mem_eff = False + self.use_sdpa = False + + def set_use_memory_efficient_attention(self, xformers, mem_eff): + self.use_memory_efficient_attention_xformers = xformers + self.use_memory_efficient_attention_mem_eff = mem_eff + + def set_use_sdpa(self, sdpa): + self.use_sdpa = sdpa + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def forward(self, hidden_states, context=None, mask=None): + if self.use_memory_efficient_attention_xformers: + return self.forward_memory_efficient_xformers(hidden_states, context, mask) + if self.use_memory_efficient_attention_mem_eff: + return self.forward_memory_efficient_mem_eff(hidden_states, context, mask) + if self.use_sdpa: + return self.forward_sdpa(hidden_states, context, mask) + + query = self.to_q(hidden_states) + context = context if context is not None else hidden_states + key = self.to_k(context) + value = self.to_v(context) + + query = self.reshape_heads_to_batch_dim(query) + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + hidden_states = self._attention(query, key, value) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + # hidden_states = self.to_out[1](hidden_states) # no dropout + return 
hidden_states + + def _attention(self, query, key, value): + if self.upcast_attention: + query = query.float() + key = key.float() + + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + attention_probs = attention_scores.softmax(dim=-1) + + # cast back to the original dtype + attention_probs = attention_probs.to(value.dtype) + + # compute attention output + hidden_states = torch.bmm(attention_probs, value) + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + # TODO support Hypernetworks + def forward_memory_efficient_xformers(self, x, context=None, mask=None): + import xformers.ops + + h = self.heads + q_in = self.to_q(x) + context = context if context is not None else x + context = context.to(x.dtype) + k_in = self.to_k(context) + v_in = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in + + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる + del q, k, v + + out = rearrange(out, "b n h d -> b n (h d)", h=h) + + out = self.to_out[0](out) + return out + + def forward_memory_efficient_mem_eff(self, x, context=None, mask=None): + flash_func = FlashAttentionFunction + + q_bucket_size = 512 + k_bucket_size = 1024 + + h = self.heads + q = self.to_q(x) + context = context if context is not None else x + context = context.to(x.dtype) + k = self.to_k(context) + v = self.to_v(context) + del context, x + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) + + out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size) + + out = rearrange(out, "b h n d -> b n (h d)") + + out = self.to_out[0](out) + return out + + def forward_sdpa(self, x, context=None, mask=None): + h = self.heads + q_in = self.to_q(x) + context = context if context is not None else x + context = context.to(x.dtype) + k_in = self.to_k(context) + v_in = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in + + out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) + + out = rearrange(out, "b h n d -> b n (h d)", h=h) + + out = self.to_out[0](out) + return out + + +# feedforward +class GEGLU(nn.Module): + r""" + A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. 
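+
+    Given an input `x`, `proj(x)` is split into two halves `a` and `b` along the last dimension, and the
+    layer returns `a * gelu(b)` (see `forward` below).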
+ """ + + def __init__(self, dim_in: int, dim_out: int): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def gelu(self, gate): + if gate.device.type != "mps": + return F.gelu(gate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + + def forward(self, hidden_states): + hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) + return hidden_states * self.gelu(gate) + + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + ): + super().__init__() + inner_dim = int(dim * 4) # mult is always 4 + + self.net = nn.ModuleList([]) + # project in + self.net.append(GEGLU(dim, inner_dim)) + # project dropout + self.net.append(nn.Identity()) # nn.Dropout(0)) # dummy for dropout with 0 + # project out + self.net.append(nn.Linear(inner_dim, dim)) + + def forward(self, hidden_states): + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, dim: int, num_attention_heads: int, attention_head_dim: int, cross_attention_dim: int, upcast_attention: bool = False + ): + super().__init__() + + self.gradient_checkpointing = False + + # 1. Self-Attn + self.attn1 = CrossAttention( + query_dim=dim, + cross_attention_dim=None, + heads=num_attention_heads, + dim_head=attention_head_dim, + upcast_attention=upcast_attention, + ) + self.ff = FeedForward(dim) + + # 2. Cross-Attn + self.attn2 = CrossAttention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + upcast_attention=upcast_attention, + ) + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(dim) + + def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool): + self.attn1.set_use_memory_efficient_attention(xformers, mem_eff) + self.attn2.set_use_memory_efficient_attention(xformers, mem_eff) + + def set_use_sdpa(self, sdpa: bool): + self.attn1.set_use_sdpa(sdpa) + self.attn2.set_use_sdpa(sdpa) + + def forward_body(self, hidden_states, context=None, timestep=None): + # 1. Self-Attention + norm_hidden_states = self.norm1(hidden_states) + + hidden_states = self.attn1(norm_hidden_states) + hidden_states + + # 2. Cross-Attention + norm_hidden_states = self.norm2(hidden_states) + hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states + + # 3. 
Feed-forward + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + + return hidden_states + + def forward(self, hidden_states, context=None, timestep=None): + if self.training and self.gradient_checkpointing: + # print("BasicTransformerBlock: checkpointing") + + def create_custom_forward(func): + def custom_forward(*inputs): + return func(*inputs) + + return custom_forward + + output = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.forward_body), hidden_states, context, timestep, use_reentrant=USE_REENTRANT + ) + else: + output = self.forward_body(hidden_states, context, timestep) + + return output + + +class Transformer2DModel(nn.Module): + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + use_linear_projection: bool = False, + upcast_attention: bool = False, + num_transformer_layers: int = 1, + ): + super().__init__() + self.in_channels = in_channels + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + self.use_linear_projection = use_linear_projection + + self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + # self.norm = GroupNorm32(32, in_channels, eps=1e-6, affine=True) + + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) + else: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + + blocks = [] + for _ in range(num_transformer_layers): + blocks.append( + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + ) + ) + + self.transformer_blocks = nn.ModuleList(blocks) + + if use_linear_projection: + self.proj_out = nn.Linear(in_channels, inner_dim) + else: + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + + self.gradient_checkpointing = False + + def set_use_memory_efficient_attention(self, xformers, mem_eff): + for transformer in self.transformer_blocks: + transformer.set_use_memory_efficient_attention(xformers, mem_eff) + + def set_use_sdpa(self, sdpa): + for transformer in self.transformer_blocks: + transformer.set_use_sdpa(sdpa) + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None): + # 1. Input + batch, _, height, weight = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + hidden_states = self.proj_in(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep) + + # 3. 
Output + if not self.use_linear_projection: + hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + + return output + + +class Upsample2D(nn.Module): + def __init__(self, channels, out_channels): + super().__init__() + self.channels = channels + self.out_channels = out_channels + self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward_body(self, hidden_states, output_size=None): + assert hidden_states.shape[1] == self.channels + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch + # https://github.com/pytorch/pytorch/issues/86679 + dtype = hidden_states.dtype + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.float32) + + # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + + # if `output_size` is passed we force the interpolation output size and do not make use of `scale_factor=2` + if output_size is None: + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + else: + hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(dtype) + + hidden_states = self.conv(hidden_states) + + return hidden_states + + def forward(self, hidden_states, output_size=None): + if self.training and self.gradient_checkpointing: + # print("Upsample2D: gradient_checkpointing") + + def create_custom_forward(func): + def custom_forward(*inputs): + return func(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.forward_body), hidden_states, output_size, use_reentrant=USE_REENTRANT + ) + else: + hidden_states = self.forward_body(hidden_states, output_size) + + return hidden_states + + +class SdxlUNet2DConditionModel(nn.Module): + _supports_gradient_checkpointing = True + + def __init__( + self, + **kwargs, + ): + super().__init__() + + self.in_channels = IN_CHANNELS + self.out_channels = OUT_CHANNELS + self.model_channels = MODEL_CHANNELS + self.time_embed_dim = TIME_EMBED_DIM + self.adm_in_channels = ADM_IN_CHANNELS + + self.gradient_checkpointing = False + # self.sample_size = sample_size + + # time embedding + self.time_embed = nn.Sequential( + nn.Linear(self.model_channels, self.time_embed_dim), + nn.SiLU(), + nn.Linear(self.time_embed_dim, self.time_embed_dim), + ) + + # label embedding + self.label_emb = nn.Sequential( + nn.Sequential( + nn.Linear(self.adm_in_channels, self.time_embed_dim), + nn.SiLU(), + nn.Linear(self.time_embed_dim, self.time_embed_dim), + ) + ) + + # input + self.input_blocks = nn.ModuleList( + [ + nn.Sequential( + nn.Conv2d(self.in_channels, self.model_channels, kernel_size=3, padding=(1, 1)), + ) + ] + ) + + # level 0 + for i in range(2): + layers = [ + ResnetBlock2D( + in_channels=1 * self.model_channels, + out_channels=1 * self.model_channels, + ), + ] + self.input_blocks.append(nn.ModuleList(layers)) + + self.input_blocks.append( 
+ nn.Sequential( + Downsample2D( + channels=1 * self.model_channels, + out_channels=1 * self.model_channels, + ), + ) + ) + + # level 1 + for i in range(2): + layers = [ + ResnetBlock2D( + in_channels=(1 if i == 0 else 2) * self.model_channels, + out_channels=2 * self.model_channels, + ), + Transformer2DModel( + num_attention_heads=2 * self.model_channels // 64, + attention_head_dim=64, + in_channels=2 * self.model_channels, + num_transformer_layers=2, + use_linear_projection=True, + cross_attention_dim=2048, + ), + ] + self.input_blocks.append(nn.ModuleList(layers)) + + self.input_blocks.append( + nn.Sequential( + Downsample2D( + channels=2 * self.model_channels, + out_channels=2 * self.model_channels, + ), + ) + ) + + # level 2 + for i in range(2): + layers = [ + ResnetBlock2D( + in_channels=(2 if i == 0 else 4) * self.model_channels, + out_channels=4 * self.model_channels, + ), + Transformer2DModel( + num_attention_heads=4 * self.model_channels // 64, + attention_head_dim=64, + in_channels=4 * self.model_channels, + num_transformer_layers=10, + use_linear_projection=True, + cross_attention_dim=2048, + ), + ] + self.input_blocks.append(nn.ModuleList(layers)) + + # mid + self.middle_block = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=4 * self.model_channels, + out_channels=4 * self.model_channels, + ), + Transformer2DModel( + num_attention_heads=4 * self.model_channels // 64, + attention_head_dim=64, + in_channels=4 * self.model_channels, + num_transformer_layers=10, + use_linear_projection=True, + cross_attention_dim=2048, + ), + ResnetBlock2D( + in_channels=4 * self.model_channels, + out_channels=4 * self.model_channels, + ), + ] + ) + + # output + self.output_blocks = nn.ModuleList([]) + + # level 2 + for i in range(3): + layers = [ + ResnetBlock2D( + in_channels=4 * self.model_channels + (4 if i <= 1 else 2) * self.model_channels, + out_channels=4 * self.model_channels, + ), + Transformer2DModel( + num_attention_heads=4 * self.model_channels // 64, + attention_head_dim=64, + in_channels=4 * self.model_channels, + num_transformer_layers=10, + use_linear_projection=True, + cross_attention_dim=2048, + ), + ] + if i == 2: + layers.append( + Upsample2D( + channels=4 * self.model_channels, + out_channels=4 * self.model_channels, + ) + ) + + self.output_blocks.append(nn.ModuleList(layers)) + + # level 1 + for i in range(3): + layers = [ + ResnetBlock2D( + in_channels=2 * self.model_channels + (4 if i == 0 else (2 if i == 1 else 1)) * self.model_channels, + out_channels=2 * self.model_channels, + ), + Transformer2DModel( + num_attention_heads=2 * self.model_channels // 64, + attention_head_dim=64, + in_channels=2 * self.model_channels, + num_transformer_layers=2, + use_linear_projection=True, + cross_attention_dim=2048, + ), + ] + if i == 2: + layers.append( + Upsample2D( + channels=2 * self.model_channels, + out_channels=2 * self.model_channels, + ) + ) + + self.output_blocks.append(nn.ModuleList(layers)) + + # level 0 + for i in range(3): + layers = [ + ResnetBlock2D( + in_channels=1 * self.model_channels + (2 if i == 0 else 1) * self.model_channels, + out_channels=1 * self.model_channels, + ), + ] + + self.output_blocks.append(nn.ModuleList(layers)) + + # output + self.out = nn.ModuleList( + [GroupNorm32(32, self.model_channels), nn.SiLU(), nn.Conv2d(self.model_channels, self.out_channels, 3, padding=1)] + ) + + # region diffusers compatibility + def prepare_config(self): + self.config = SimpleNamespace() + + @property + def dtype(self) -> torch.dtype: + # `torch.dtype`: The dtype of 
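
With the input, middle and output blocks defined, the whole U-Net maps SDXL latents (4 channels at 1/8 of the pixel resolution) back to the same shape, conditioned on the text-encoder context (`cross_attention_dim=2048`) and the `y` vector fed to `label_emb`. A rough shape check, assuming the module lives at `library/sdxl_original_unet.py` as imported elsewhere in this patch; note that instantiating the full U-Net needs several GB of RAM and a CPU forward pass is slow:

```python
import torch
from library.sdxl_original_unet import SdxlUNet2DConditionModel

unet = SdxlUNet2DConditionModel().eval()

x = torch.randn(1, 4, 64, 64)            # latents for a 512x512 image
t = torch.randint(0, 1000, (1,))         # one timestep per sample
ctx = torch.randn(1, 77, 2048)           # text encoder hidden states
y = torch.randn(1, unet.adm_in_channels) # pooled text emb + size embeddings

with torch.no_grad():
    out = unet(x, t, ctx, y)
print(out.shape)                         # expected: torch.Size([1, 4, 64, 64])
```
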
the module (assuming that all the module parameters have the same dtype). + return get_parameter_dtype(self) + + @property + def device(self) -> torch.device: + # `torch.device`: The device on which the module is (assuming that all the module parameters are on the same device). + return get_parameter_device(self) + + def set_attention_slice(self, slice_size): + raise NotImplementedError("Attention slicing is not supported for this model.") + + def is_gradient_checkpointing(self) -> bool: + return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules()) + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + self.set_gradient_checkpointing(value=True) + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + self.set_gradient_checkpointing(value=False) + + def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool) -> None: + blocks = self.input_blocks + [self.middle_block] + self.output_blocks + for block in blocks: + for module in block: + if hasattr(module, "set_use_memory_efficient_attention"): + # print(module.__class__.__name__) + module.set_use_memory_efficient_attention(xformers, mem_eff) + + def set_use_sdpa(self, sdpa: bool) -> None: + blocks = self.input_blocks + [self.middle_block] + self.output_blocks + for block in blocks: + for module in block: + if hasattr(module, "set_use_sdpa"): + module.set_use_sdpa(sdpa) + + def set_gradient_checkpointing(self, value=False): + blocks = self.input_blocks + [self.middle_block] + self.output_blocks + for block in blocks: + for module in block.modules(): + if hasattr(module, "gradient_checkpointing"): + # print(module.__class__.__name__, module.gradient_checkpointing, "->", value) + module.gradient_checkpointing = value + + # endregion + + def forward(self, x, timesteps=None, context=None, y=None, **kwargs): + # broadcast timesteps to batch dimension + timesteps = timesteps.expand(x.shape[0]) + + hs = [] + t_emb = get_timestep_embedding(timesteps, self.model_channels) # , repeat_only=False) + t_emb = t_emb.to(x.dtype) + emb = self.time_embed(t_emb) + + assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}" + assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}" + # assert x.dtype == self.dtype + emb = emb + self.label_emb(y) + + def call_module(module, h, emb, context): + x = h + for layer in module: + # print(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None) + if isinstance(layer, ResnetBlock2D): + x = layer(x, emb) + elif isinstance(layer, Transformer2DModel): + x = layer(x, context) + else: + x = layer(x) + return x + + # h = x.type(self.dtype) + h = x + for module in self.input_blocks: + h = call_module(module, h, emb, context) + hs.append(h) + + h = call_module(self.middle_block, h, emb, context) + + for module in self.output_blocks: + h = torch.cat([h, hs.pop()], dim=1) + h = call_module(module, h, emb, context) + + h = h.type(x.dtype) + h = call_module(self.out, h, emb, context) + + return h + + +if __name__ == "__main__": + import time + + print("create unet") + unet = SdxlUNet2DConditionModel() + + unet.to("cuda") + unet.set_use_memory_efficient_attention(True, False) + unet.set_gradient_checkpointing(True) + unet.train() + + # 使用メモリ量確認用の疑似学習ループ + print("preparing optimizer") + + # optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working + + # import bitsandbytes + # optimizer = 
bitsandbytes.adam.Adam8bit(unet.parameters(), lr=1e-3) # not working + # optimizer = bitsandbytes.optim.RMSprop8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2 + # optimizer=bitsandbytes.optim.Adagrad8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2 + + import transformers + + optimizer = transformers.optimization.Adafactor(unet.parameters(), relative_step=True) # working at 22.2GB with torch2 + + scaler = torch.cuda.amp.GradScaler(enabled=True) + + print("start training") + steps = 10 + batch_size = 1 + + for step in range(steps): + print(f"step {step}") + if step == 1: + time_start = time.perf_counter() + + x = torch.randn(batch_size, 4, 128, 128).cuda() # 1024x1024 + t = torch.randint(low=0, high=10, size=(batch_size,), device="cuda") + ctx = torch.randn(batch_size, 77, 2048).cuda() + y = torch.randn(batch_size, ADM_IN_CHANNELS).cuda() + + with torch.cuda.amp.autocast(enabled=True): + output = unet(x, t, ctx, y) + target = torch.randn_like(output) + loss = torch.nn.functional.mse_loss(output, target) + + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad(set_to_none=True) + + time_end = time.perf_counter() + print(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps") diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py new file mode 100644 index 000000000..f637d9931 --- /dev/null +++ b/library/sdxl_train_util.py @@ -0,0 +1,369 @@ +import argparse +import gc +import math +import os +from typing import Optional +import torch +from accelerate import init_empty_weights +from tqdm import tqdm +from transformers import CLIPTokenizer +from library import model_util, sdxl_model_util, train_util, sdxl_original_unet +from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline + +TOKENIZER1_PATH = "openai/clip-vit-large-patch14" +TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + +# DEFAULT_NOISE_OFFSET = 0.0357 + + +def load_target_model(args, accelerator, model_version: str, weight_dtype): + # load models for each process + model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16 + for pi in range(accelerator.state.num_processes): + if pi == accelerator.state.local_process_index: + print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") + + ( + load_stable_diffusion_format, + text_encoder1, + text_encoder2, + vae, + unet, + logit_scale, + ckpt_info, + ) = _load_target_model( + args.pretrained_model_name_or_path, + args.vae, + model_version, + weight_dtype, + accelerator.device if args.lowram else "cpu", + model_dtype, + ) + + # work on low-ram device + if args.lowram: + text_encoder1.to(accelerator.device) + text_encoder2.to(accelerator.device) + unet.to(accelerator.device) + vae.to(accelerator.device) + + gc.collect() + torch.cuda.empty_cache() + accelerator.wait_for_everyone() + + text_encoder1, text_encoder2, unet = train_util.transform_models_if_DDP([text_encoder1, text_encoder2, unet]) + + return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info + + +def _load_target_model( + name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None +): + # model_dtype only work with full fp16/bf16 + name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path + load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers + + if 
load_stable_diffusion_format: + print(f"load StableDiffusion checkpoint: {name_or_path}") + ( + text_encoder1, + text_encoder2, + vae, + unet, + logit_scale, + ckpt_info, + ) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype) + else: + # Diffusers model is loaded to CPU + from diffusers import StableDiffusionXLPipeline + + variant = "fp16" if weight_dtype == torch.float16 else None + print(f"load Diffusers pretrained models: {name_or_path}, variant={variant}") + try: + try: + pipe = StableDiffusionXLPipeline.from_pretrained( + name_or_path, torch_dtype=model_dtype, variant=variant, tokenizer=None + ) + except EnvironmentError as ex: + if variant is not None: + print("try to load fp32 model") + pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, variant=None, tokenizer=None) + else: + raise ex + except EnvironmentError as ex: + print( + f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}" + ) + raise ex + + text_encoder1 = pipe.text_encoder + text_encoder2 = pipe.text_encoder_2 + + # convert to fp32 for cache text_encoders outputs + if text_encoder1.dtype != torch.float32: + text_encoder1 = text_encoder1.to(dtype=torch.float32) + if text_encoder2.dtype != torch.float32: + text_encoder2 = text_encoder2.to(dtype=torch.float32) + + vae = pipe.vae + unet = pipe.unet + del pipe + + # Diffusers U-Net to original U-Net + state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict()) + with init_empty_weights(): + unet = sdxl_original_unet.SdxlUNet2DConditionModel() # overwrite unet + sdxl_model_util._load_state_dict_on_device(unet, state_dict, device=device, dtype=model_dtype) + print("U-Net converted to original U-Net") + + logit_scale = None + ckpt_info = None + + # VAEを読み込む + if vae_path is not None: + vae = model_util.load_vae(vae_path, weight_dtype) + print("additional VAE loaded") + + return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info + + +def load_tokenizers(args: argparse.Namespace): + print("prepare tokenizers") + + original_paths = [TOKENIZER1_PATH, TOKENIZER2_PATH] + tokeniers = [] + for i, original_path in enumerate(original_paths): + tokenizer: CLIPTokenizer = None + if args.tokenizer_cache_dir: + local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_")) + if os.path.exists(local_tokenizer_path): + print(f"load tokenizer from cache: {local_tokenizer_path}") + tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) + + if tokenizer is None: + tokenizer = CLIPTokenizer.from_pretrained(original_path) + + if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path): + print(f"save Tokenizer to cache: {local_tokenizer_path}") + tokenizer.save_pretrained(local_tokenizer_path) + + if i == 1: + tokenizer.pad_token_id = 0 # fix pad token id to make same as open clip tokenizer + + tokeniers.append(tokenizer) + + if hasattr(args, "max_token_length") and args.max_token_length is not None: + print(f"update token length: {args.max_token_length}") + + return tokeniers + + +def match_mixed_precision(args, weight_dtype): + if args.full_fp16: + assert ( + weight_dtype == torch.float16 + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + return weight_dtype + elif args.full_bf16: + assert ( + weight_dtype == torch.bfloat16 + ), "full_bf16 requires mixed 
precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + return weight_dtype + else: + return None + + +def timestep_embedding(timesteps, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + device=timesteps.device + ) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +def get_timestep_embedding(x, outdim): + assert len(x.shape) == 2 + b, dims = x.shape[0], x.shape[1] + x = torch.flatten(x) + emb = timestep_embedding(x, outdim) + emb = torch.reshape(emb, (b, dims * outdim)) + return emb + + +def get_size_embeddings(orig_size, crop_size, target_size, device): + emb1 = get_timestep_embedding(orig_size, 256) + emb2 = get_timestep_embedding(crop_size, 256) + emb3 = get_timestep_embedding(target_size, 256) + vector = torch.cat([emb1, emb2, emb3], dim=1).to(device) + return vector + + +def save_sd_model_on_train_end( + args: argparse.Namespace, + src_path: str, + save_stable_diffusion_format: bool, + use_safetensors: bool, + save_dtype: torch.dtype, + epoch: int, + global_step: int, + text_encoder1, + text_encoder2, + unet, + vae, + logit_scale, + ckpt_info, +): + def sd_saver(ckpt_file, epoch_no, global_step): + sai_metadata = train_util.get_sai_model_spec(None, args, True, False, False, is_stable_diffusion_ckpt=True) + sdxl_model_util.save_stable_diffusion_checkpoint( + ckpt_file, + text_encoder1, + text_encoder2, + unet, + epoch_no, + global_step, + ckpt_info, + vae, + logit_scale, + sai_metadata, + save_dtype, + ) + + def diffusers_saver(out_dir): + sdxl_model_util.save_diffusers_checkpoint( + out_dir, + text_encoder1, + text_encoder2, + unet, + src_path, + vae, + use_safetensors=use_safetensors, + save_dtype=save_dtype, + ) + + train_util.save_sd_model_on_train_end_common( + args, save_stable_diffusion_format, use_safetensors, epoch, global_step, sd_saver, diffusers_saver + ) + + +# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合している +# on_epoch_end: Trueならepoch終了時、Falseならstep経過時 +def save_sd_model_on_epoch_end_or_stepwise( + args: argparse.Namespace, + on_epoch_end: bool, + accelerator, + src_path, + save_stable_diffusion_format: bool, + use_safetensors: bool, + save_dtype: torch.dtype, + epoch: int, + num_train_epochs: int, + global_step: int, + text_encoder1, + text_encoder2, + unet, + vae, + logit_scale, + ckpt_info, +): + def sd_saver(ckpt_file, epoch_no, global_step): + sai_metadata = train_util.get_sai_model_spec(None, args, True, False, False, is_stable_diffusion_ckpt=True) + sdxl_model_util.save_stable_diffusion_checkpoint( + ckpt_file, + text_encoder1, + text_encoder2, + unet, + epoch_no, + global_step, + ckpt_info, + vae, + logit_scale, + sai_metadata, + save_dtype, + ) + + def diffusers_saver(out_dir): + sdxl_model_util.save_diffusers_checkpoint( + out_dir, + text_encoder1, + text_encoder2, + unet, + src_path, + vae, + use_safetensors=use_safetensors, + save_dtype=save_dtype, + ) + + train_util.save_sd_model_on_epoch_end_or_stepwise_common( + args, + on_epoch_end, + 
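
`get_size_embeddings` builds SDXL's size conditioning: each of the three pairs (original size, crop coordinates, target size) goes through the sinusoidal `timestep_embedding` with 256 dimensions per value, and the results are concatenated. A quick sanity check of the resulting width, using the module added by this patch:

```python
import torch
from library.sdxl_train_util import get_size_embeddings

orig_size   = torch.tensor([[1024, 1024]])  # (batch, 2)
crop_coords = torch.tensor([[0, 0]])
target_size = torch.tensor([[1024, 1024]])

emb = get_size_embeddings(orig_size, crop_coords, target_size, "cpu")
print(emb.shape)  # torch.Size([1, 1536]) = 3 conditions x 2 values x 256 dims
```

This 1536-dim vector is combined with the pooled output of the second text encoder to form the `y` conditioning consumed by the U-Net's `label_emb`.
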
accelerator, + save_stable_diffusion_format, + use_safetensors, + epoch, + num_train_epochs, + global_step, + sd_saver, + diffusers_saver, + ) + + +def add_sdxl_training_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする" + ) + parser.add_argument( + "--cache_text_encoder_outputs_to_disk", + action="store_true", + help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", + ) + + +def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True): + assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません" + if args.v_parameterization: + print("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります") + + if args.clip_skip is not None: + print("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません") + + # if args.multires_noise_iterations: + # print( + # f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります" + # ) + # else: + # if args.noise_offset is None: + # args.noise_offset = DEFAULT_NOISE_OFFSET + # elif args.noise_offset != DEFAULT_NOISE_OFFSET: + # print( + # f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています" + # ) + # print(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました") + + assert ( + not hasattr(args, "weighted_captions") or not args.weighted_captions + ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません" + + if supportTextEncoderCaching: + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + args.cache_text_encoder_outputs = True + print( + "cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / " + + "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました" + ) + + +def sample_images(*args, **kwargs): + return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs) diff --git a/library/slicing_vae.py b/library/slicing_vae.py index 490b5a75d..31b2bd0a4 100644 --- a/library/slicing_vae.py +++ b/library/slicing_vae.py @@ -22,10 +22,10 @@ from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.modeling_utils import ModelMixin -from diffusers.utils import BaseOutput -from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block, ResnetBlock2D -from diffusers.models.vae import DecoderOutput, Encoder, AutoencoderKLOutput, DiagonalGaussianDistribution +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block +from diffusers.models.vae import DecoderOutput, DiagonalGaussianDistribution +from diffusers.models.autoencoder_kl import AutoencoderKLOutput def slice_h(x, num_slices): @@ -209,7 +209,7 @@ def __init__( downsample_padding=0, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, - attn_num_head_channels=None, + attention_head_dim=output_channel, temb_channels=None, ) self.down_blocks.append(down_block) @@ -221,7 +221,7 @@ def __init__( resnet_act_fn=act_fn, output_scale_factor=1, 
resnet_time_scale_shift="default", - attn_num_head_channels=None, + attention_head_dim=block_out_channels[-1], resnet_groups=norm_num_groups, temb_channels=None, ) @@ -381,7 +381,7 @@ def __init__( resnet_act_fn=act_fn, output_scale_factor=1, resnet_time_scale_shift="default", - attn_num_head_channels=None, + attention_head_dim=block_out_channels[-1], resnet_groups=norm_num_groups, temb_channels=None, ) @@ -406,7 +406,7 @@ def __init__( resnet_eps=1e-6, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, - attn_num_head_channels=None, + attention_head_dim=output_channel, temb_channels=None, ) self.up_blocks.append(up_block) diff --git a/library/train_util.py b/library/train_util.py index d1405643c..35bfb5f5b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -34,9 +34,8 @@ from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer from torchvision import transforms -from transformers import CLIPTokenizer +from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection import transformers -import diffusers from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION from diffusers import ( StableDiffusionPipeline, @@ -51,18 +50,23 @@ HeunDiscreteScheduler, KDPM2DiscreteScheduler, KDPM2AncestralDiscreteScheduler, + AutoencoderKL, ) +from library import custom_train_functions +from library.original_unet import UNet2DConditionModel from huggingface_hub import hf_hub_download -import albumentations as albu import numpy as np from PIL import Image import cv2 -from einops import rearrange -from torch import einsum import safetensors.torch from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline import library.model_util as model_util import library.huggingface_util as huggingface_util +import library.sai_model_spec as sai_model_spec + +# from library.attention_processors import FlashAttnProcessor +# from library.hypernetwork import replace_attentions_for_hypernetwork +from library.original_unet import UNet2DConditionModel # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う TOKENIZER_PATH = "openai/clip-vit-large-patch14" @@ -85,6 +89,29 @@ IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"] +try: + import pillow_avif + + IMAGE_EXTENSIONS.extend([".avif", ".AVIF"]) +except: + pass + +try: + from jxlpy import JXLImagePlugin + + IMAGE_EXTENSIONS.extend([".jxl", ".JXL"]) +except: + pass + +IMAGE_TRANSFORMS = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] +) + +TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz" + class ImageInfo: def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None: @@ -99,7 +126,15 @@ def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, self.latents: torch.Tensor = None self.latents_flipped: torch.Tensor = None self.latents_npz: str = None - self.latents_npz_flipped: str = None + self.latents_original_size: Tuple[int, int] = None # original image size, not latents size + self.latents_crop_ltrb: Tuple[int, int] = None # crop left top right bottom in original pixel size, not latents size + self.cond_img_path: str = None + self.image: Optional[Image.Image] = None # optional, original PIL Image + # SDXL, optional + self.text_encoder_outputs_npz: Optional[str] = None + self.text_encoder_outputs1: Optional[torch.Tensor] = None + self.text_encoder_outputs2: Optional[torch.Tensor] = None + 
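
The `slicing_vae.py` hunks above track a diffusers API change: the block constructors' `attn_num_head_channels` argument became `attention_head_dim`, and `ModelMixin`/`AutoencoderKLOutput` moved to new module paths. If code has to run against both old and new diffusers releases, a small (hypothetical) shim like the following can pick the supported keyword at runtime:

```python
import inspect
from diffusers.models.unet_2d_blocks import UNetMidBlock2D

def head_dim_kwarg(value):
    """Return the head-dim keyword accepted by the installed diffusers version."""
    params = inspect.signature(UNetMidBlock2D.__init__).parameters
    key = "attention_head_dim" if "attention_head_dim" in params else "attn_num_head_channels"
    return {key: value}

# e.g. UNetMidBlock2D(in_channels=512, temb_channels=None, **head_dim_kwarg(512))
```
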
self.text_encoder_pool2: Optional[torch.Tensor] = None class BucketManager: @@ -117,11 +152,11 @@ def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None self.resos = [] self.reso_to_id = {} - self.buckets = [] # 前処理時は (image_key, image)、学習時は image_key + self.buckets = [] # 前処理時は (image_key, image, original size, crop left/top)、学習時は image_key - def add_image(self, reso, image): + def add_image(self, reso, image_or_info): bucket_id = self.reso_to_id[reso] - self.buckets[bucket_id].append(image) + self.buckets[bucket_id].append(image_or_info) def shuffle(self): for bucket in self.buckets: @@ -168,6 +203,7 @@ def round_to_steps(self, x): def select_bucket(self, image_width, image_height): aspect_ratio = image_width / image_height if not self.no_upscale: + # 拡大および縮小を行う # 同じaspect ratioがあるかもしれないので(fine tuningで、no_upscale=Trueで前処理した場合)、解像度が同じものを優先する reso = (image_width, image_height) if reso in self.predefined_resos_set: @@ -186,6 +222,7 @@ def select_bucket(self, image_width, image_height): resized_size = (int(image_width * scale + 0.5), int(image_height * scale + 0.5)) # print("use predef", image_width, image_height, reso, resized_size) else: + # 縮小のみを行う if image_width * image_height > self.max_area: # 画像が大きすぎるのでアスペクト比を保ったまま縮小することを前提にbucketを決める resized_width = math.sqrt(self.max_area * aspect_ratio) @@ -225,6 +262,26 @@ def select_bucket(self, image_width, image_height): ar_error = (reso[0] / reso[1]) - aspect_ratio return reso, resized_size, ar_error + @staticmethod + def get_crop_ltrb(bucket_reso: Tuple[int, int], image_size: Tuple[int, int]): + # Stability AIの前処理に合わせてcrop left/topを計算する。crop rightはflipのaugmentationのために求める + # Calculate crop left/top according to the preprocessing of Stability AI. Crop right is calculated for flip augmentation. 
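
`get_crop_ltrb` (defined just below) computes, in bucket coordinates, where the aspect-preserving resize of the original image sits inside the bucket; these values feed SDXL's crop conditioning, and the right/bottom edges are kept so the left offset can be recomputed for flipped samples. A worked example:

```python
from library.train_util import BucketManager

# A 1536x1024 image assigned to a 1024x1024 bucket: it is scaled to the
# bucket width, so only the vertical placement is offset.
ltrb = BucketManager.get_crop_ltrb(bucket_reso=(1024, 1024), image_size=(1536, 1024))
print(ltrb)  # roughly (0, 170.0, 1024, 852.7)
```
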
+ + bucket_ar = bucket_reso[0] / bucket_reso[1] + image_ar = image_size[0] / image_size[1] + if bucket_ar > image_ar: + # bucketのほうが横長→縦を合わせる + resized_width = bucket_reso[1] * image_ar + resized_height = bucket_reso[1] + else: + resized_width = bucket_reso[0] + resized_height = bucket_reso[0] / image_ar + crop_left = (bucket_reso[0] - resized_width) // 2 + crop_top = (bucket_reso[1] - resized_height) // 2 + crop_right = crop_left + resized_width + crop_bottom = crop_top + resized_height + return crop_left, crop_top, crop_right, crop_bottom + class BucketBatchIndex(NamedTuple): bucket_index: int @@ -233,43 +290,40 @@ class BucketBatchIndex(NamedTuple): class AugHelper: + # albumentationsへの依存をなくしたがとりあえず同じinterfaceを持たせる + def __init__(self): - # prepare all possible augmentators - color_aug_method = albu.OneOf( - [ - albu.HueSaturationValue(8, 0, 0, p=0.5), - albu.RandomGamma((95, 105), p=0.5), - ], - p=0.33, - ) - flip_aug_method = albu.HorizontalFlip(p=0.5) - - # key: (use_color_aug, use_flip_aug) - self.augmentors = { - (True, True): albu.Compose( - [ - color_aug_method, - flip_aug_method, - ], - p=1.0, - ), - (True, False): albu.Compose( - [ - color_aug_method, - ], - p=1.0, - ), - (False, True): albu.Compose( - [ - flip_aug_method, - ], - p=1.0, - ), - (False, False): None, - } - - def get_augmentor(self, use_color_aug: bool, use_flip_aug: bool) -> Optional[albu.Compose]: - return self.augmentors[(use_color_aug, use_flip_aug)] + pass + + def color_aug(self, image: np.ndarray): + # self.color_aug_method = albu.OneOf( + # [ + # albu.HueSaturationValue(8, 0, 0, p=0.5), + # albu.RandomGamma((95, 105), p=0.5), + # ], + # p=0.33, + # ) + hue_shift_limit = 8 + + # remove dependency to albumentations + if random.random() <= 0.33: + if random.random() > 0.5: + # hue shift + hsv_img = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) + hue_shift = random.uniform(-hue_shift_limit, hue_shift_limit) + if hue_shift < 0: + hue_shift = 180 + hue_shift + hsv_img[:, :, 0] = (hsv_img[:, :, 0] + hue_shift) % 180 + image = cv2.cvtColor(hsv_img, cv2.COLOR_HSV2BGR) + else: + # random gamma + gamma = random.uniform(0.95, 1.05) + image = np.clip(image**gamma, 0, 255).astype(np.uint8) + + return {"image": image} + + def get_augmentor(self, use_color_aug: bool): # -> Optional[Callable[[np.ndarray], Dict[str, np.ndarray]]]: + return self.color_aug if use_color_aug else None class BaseSubset: @@ -286,6 +340,8 @@ def __init__( caption_dropout_rate: float, caption_dropout_every_n_epochs: int, caption_tag_dropout_rate: float, + caption_prefix: Optional[str], + caption_suffix: Optional[str], token_warmup_min: int, token_warmup_step: Union[float, int], ) -> None: @@ -300,6 +356,8 @@ def __init__( self.caption_dropout_rate = caption_dropout_rate self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs self.caption_tag_dropout_rate = caption_tag_dropout_rate + self.caption_prefix = caption_prefix + self.caption_suffix = caption_suffix self.token_warmup_min = token_warmup_min # step=0におけるタグの数 self.token_warmup_step = token_warmup_step # N(N<1ならN*max_train_steps)ステップ目でタグの数が最大になる @@ -324,6 +382,8 @@ def __init__( caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate, + caption_prefix, + caption_suffix, token_warmup_min, token_warmup_step, ) -> None: @@ -341,6 +401,8 @@ def __init__( caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate, + caption_prefix, + caption_suffix, token_warmup_min, token_warmup_step, ) @@ -372,6 +434,8 @@ def __init__( caption_dropout_rate, 
caption_dropout_every_n_epochs, caption_tag_dropout_rate, + caption_prefix, + caption_suffix, token_warmup_min, token_warmup_step, ) -> None: @@ -389,6 +453,8 @@ def __init__( caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate, + caption_prefix, + caption_suffix, token_warmup_min, token_warmup_step, ) @@ -401,12 +467,70 @@ def __eq__(self, other) -> bool: return self.metadata_file == other.metadata_file +class ControlNetSubset(BaseSubset): + def __init__( + self, + image_dir: str, + conditioning_data_dir: str, + caption_extension: str, + num_repeats, + shuffle_caption, + keep_tokens, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + caption_prefix, + caption_suffix, + token_warmup_min, + token_warmup_step, + ) -> None: + assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" + + super().__init__( + image_dir, + num_repeats, + shuffle_caption, + keep_tokens, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + caption_prefix, + caption_suffix, + token_warmup_min, + token_warmup_step, + ) + + self.conditioning_data_dir = conditioning_data_dir + self.caption_extension = caption_extension + if self.caption_extension and not self.caption_extension.startswith("."): + self.caption_extension = "." + self.caption_extension + + def __eq__(self, other) -> bool: + if not isinstance(other, ControlNetSubset): + return NotImplemented + return self.image_dir == other.image_dir and self.conditioning_data_dir == other.conditioning_data_dir + + class BaseDataset(torch.utils.data.Dataset): def __init__( - self, tokenizer: CLIPTokenizer, max_token_length: int, resolution: Optional[Tuple[int, int]], debug_dataset: bool + self, + tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]], + max_token_length: int, + resolution: Optional[Tuple[int, int]], + debug_dataset: bool, ) -> None: super().__init__() - self.tokenizer = tokenizer + + self.tokenizers = tokenizer if isinstance(tokenizer, list) else [tokenizer] + self.max_token_length = max_token_length # width/height is used when enable_bucket==False self.width, self.height = (None, None) if resolution is None else resolution @@ -427,7 +551,7 @@ def __init__( self.bucket_no_upscale = None self.bucket_info = None # for metadata - self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2 + self.tokenizer_max_length = self.tokenizers[0].model_max_length if max_token_length is None else max_token_length + 2 self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ @@ -438,21 +562,22 @@ def __init__( # augmentation self.aug_helper = AugHelper() - self.image_transforms = transforms.Compose( - [ - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ] - ) + self.image_transforms = IMAGE_TRANSFORMS self.image_data: Dict[str, ImageInfo] = {} self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {} self.replacements = {} + # caching + self.caching_mode = None # None, 'latents', 'text' + def set_seed(self, seed): self.seed = seed + def set_caching_mode(self, mode): + self.caching_mode = mode + def set_current_epoch(self, epoch): if not self.current_epoch == epoch: # epochが切り替わったらバケツをシャッフルする self.shuffle_buckets() @@ -486,6 +611,12 @@ def add_replacement(self, str_from, str_to): self.replacements[str_from] = str_to def 
process_caption(self, subset: BaseSubset, caption): + # caption に prefix/suffix を付ける + if subset.caption_prefix: + caption = subset.caption_prefix + " " + caption + if subset.caption_suffix: + caption = caption + " " + subset.caption_suffix + # dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い is_drop_out = subset.caption_dropout_rate > 0 and random.random() < subset.caption_dropout_rate is_drop_out = ( @@ -543,48 +674,49 @@ def dropout_tags(tokens): return caption - def get_input_ids(self, caption): - input_ids = self.tokenizer( + def get_input_ids(self, caption, tokenizer=None): + if tokenizer is None: + tokenizer = self.tokenizers[0] + + input_ids = tokenizer( caption, padding="max_length", truncation=True, max_length=self.tokenizer_max_length, return_tensors="pt" ).input_ids - if self.tokenizer_max_length > self.tokenizer.model_max_length: + if self.tokenizer_max_length > tokenizer.model_max_length: input_ids = input_ids.squeeze(0) iids_list = [] - if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id: + if tokenizer.pad_token_id == tokenizer.eos_token_id: # v1 # 77以上の時は " .... " でトータル227とかになっているので、"..."の三連に変換する # 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に for i in range( - 1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2 + 1, self.tokenizer_max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2 ): # (1, 152, 75) ids_chunk = ( input_ids[0].unsqueeze(0), - input_ids[i : i + self.tokenizer.model_max_length - 2], + input_ids[i : i + tokenizer.model_max_length - 2], input_ids[-1].unsqueeze(0), ) ids_chunk = torch.cat(ids_chunk) iids_list.append(ids_chunk) else: - # v2 + # v2 or SDXL # 77以上の時は " .... ..." でトータル227とかになっているので、"... ..."の三連に変換する - for i in range( - 1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2 - ): + for i in range(1, self.tokenizer_max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2): ids_chunk = ( input_ids[0].unsqueeze(0), # BOS - input_ids[i : i + self.tokenizer.model_max_length - 2], + input_ids[i : i + tokenizer.model_max_length - 2], input_ids[-1].unsqueeze(0), ) # PAD or EOS ids_chunk = torch.cat(ids_chunk) # 末尾が または の場合は、何もしなくてよい # 末尾が x の場合は末尾を に変える(x なら結果的に変化なし) - if ids_chunk[-2] != self.tokenizer.eos_token_id and ids_chunk[-2] != self.tokenizer.pad_token_id: - ids_chunk[-1] = self.tokenizer.eos_token_id + if ids_chunk[-2] != tokenizer.eos_token_id and ids_chunk[-2] != tokenizer.pad_token_id: + ids_chunk[-1] = tokenizer.eos_token_id # 先頭が ... の場合は ... 
に変える - if ids_chunk[1] == self.tokenizer.pad_token_id: - ids_chunk[1] = self.tokenizer.eos_token_id + if ids_chunk[1] == tokenizer.pad_token_id: + ids_chunk[1] = tokenizer.eos_token_id iids_list.append(ids_chunk) @@ -697,42 +829,30 @@ def shuffle_buckets(self): random.shuffle(self.buckets_indices) self.bucket_manager.shuffle() - def load_image(self, image_path): - image = Image.open(image_path) - if not image.mode == "RGB": - image = image.convert("RGB") - img = np.array(image, np.uint8) - return img - - def trim_and_resize_if_required(self, subset: BaseSubset, image, reso, resized_size): - image_height, image_width = image.shape[0:2] - - if image_width != resized_size[0] or image_height != resized_size[1]: - # リサイズする - image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ - - image_height, image_width = image.shape[0:2] - if image_width > reso[0]: - trim_size = image_width - reso[0] - p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size) - # print("w", trim_size, p) - image = image[:, p : p + reso[0]] - if image_height > reso[1]: - trim_size = image_height - reso[1] - p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size) - # print("h", trim_size, p) - image = image[p : p + reso[1]] - - assert ( - image.shape[0] == reso[1] and image.shape[1] == reso[0] - ), f"internal error, illegal trimmed size: {image.shape}, {reso}" - return image + def verify_bucket_reso_steps(self, min_steps: int): + assert self.bucket_reso_steps is None or self.bucket_reso_steps % min_steps == 0, ( + f"bucket_reso_steps is {self.bucket_reso_steps}. it must be divisible by {min_steps}.\n" + + f"bucket_reso_stepsが{self.bucket_reso_steps}です。{min_steps}で割り切れる必要があります" + ) def is_latent_cacheable(self): return all([not subset.color_aug and not subset.random_crop for subset in self.subsets]) + def is_text_encoder_output_cacheable(self): + return all( + [ + not ( + subset.caption_dropout_rate > 0 + or subset.shuffle_caption + or subset.token_warmup_step > 0 + or subset.caption_tag_dropout_rate > 0 + ) + for subset in self.subsets + ] + ) + def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): - # ちょっと速くした + # マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと print("caching latents.") image_infos = list(self.image_data.values()) @@ -743,42 +863,22 @@ def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_proc # split by resolution batches = [] batch = [] - for info in image_infos: + print("checking cache validity...") + for info in tqdm(image_infos): subset = self.image_to_subset[info.image_key] - if info.latents_npz is not None: - info.latents = self.load_latents_from_npz(info, False) - info.latents = torch.FloatTensor(info.latents) - - # might be None, but that's ok because check is done in dataset - info.latents_flipped = self.load_latents_from_npz(info, True) - if info.latents_flipped is not None: - info.latents_flipped = torch.FloatTensor(info.latents_flipped) + if info.latents_npz is not None: # fine tuning dataset continue # check disk cache exists and size of latents if cache_to_disk: - # TODO: refactor to unify with FineTuningDataset info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz" - info.latents_npz_flipped = os.path.splitext(info.absolute_path)[0] + "_flip.npz" - if not is_main_process: + if not is_main_process: # store to info only continue - cache_available = False - expected_latents_size = (info.bucket_reso[1] // 8, info.bucket_reso[0] // 8) # 
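
`get_input_ids` above implements the long-prompt handling: with `max_token_length` set to 150 or 225, the caption is tokenized once at the extended length and then re-packed into two or three 77-token windows, each wrapped with its own BOS and EOS (with a trailing pad fixed up to EOS for the v2/SDXL tokenizer). A standalone sketch of the 225-token case, assuming network access to download the tokenizer:

```python
import torch
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
caption = ", ".join(["a very long tag"] * 80)           # deliberately > 75 tokens

max_token_length = 225 + 2                              # 225 plus BOS/EOS
ids = tokenizer(caption, padding="max_length", truncation=True,
                max_length=max_token_length, return_tensors="pt").input_ids.squeeze(0)

chunks = []
step = tokenizer.model_max_length - 2                   # 75 content tokens per chunk
for i in range(1, max_token_length - tokenizer.model_max_length + 2, step):
    chunk = torch.cat([ids[:1], ids[i : i + step], ids[-1:]])  # BOS + 75 tokens + EOS/PAD
    chunks.append(chunk)
print(torch.stack(chunks).shape)                        # torch.Size([3, 77])
```
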
bucket_resoはWxHなので注意 - if os.path.exists(info.latents_npz): - cached_latents = np.load(info.latents_npz)["arr_0"] - if cached_latents.shape[1:3] == expected_latents_size: - cache_available = True - - if subset.flip_aug: - cache_available = False - if os.path.exists(info.latents_npz_flipped): - cached_latents_flipped = np.load(info.latents_npz_flipped)["arr_0"] - if cached_latents_flipped.shape[1:3] == expected_latents_size: - cache_available = True - - if cache_available: + cache_available = is_disk_cached_latents_is_expected(info.bucket_reso, info.latents_npz, subset.flip_aug) + + if cache_available: # do not add to batch continue # if last member of batch has different resolution, flush the batch @@ -796,44 +896,83 @@ def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_proc if len(batch) > 0: batches.append(batch) - if cache_to_disk and not is_main_process: # don't cache latents in non-main process, set to info only + if cache_to_disk and not is_main_process: # if cache to disk, don't cache latents in non-main process, set to info only return - # iterate batches + # iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded + print("caching latents...") for batch in tqdm(batches, smoothing=1, total=len(batches)): - images = [] - for info in batch: - image = self.load_image(info.absolute_path) - image = self.trim_and_resize_if_required(subset, image, info.bucket_reso, info.resized_size) - image = self.image_transforms(image) - images.append(image) + cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.random_crop) + + # weight_dtypeを指定するとText Encoderそのもの、およひ出力がweight_dtypeになる + # SDXLでのみ有効だが、datasetのメソッドとする必要があるので、sdxl_train_util.pyではなくこちらに実装する + # SD1/2に対応するにはv2のフラグを持つ必要があるので後回し + def cache_text_encoder_outputs( + self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True + ): + assert len(tokenizers) == 2, "only support SDXL" + + # latentsのキャッシュと同様に、ディスクへのキャッシュに対応する + # またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと + print("caching text encoder outputs.") + image_infos = list(self.image_data.values()) - img_tensors = torch.stack(images, dim=0) - img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype) + print("checking cache existence...") + image_infos_to_cache = [] + for info in tqdm(image_infos): + # subset = self.image_to_subset[info.image_key] + if cache_to_disk: + te_out_npz = os.path.splitext(info.absolute_path)[0] + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX + info.text_encoder_outputs_npz = te_out_npz - latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") + if not is_main_process: # store to info only + continue - for info, latent in zip(batch, latents): - if cache_to_disk: - np.savez(info.latents_npz, latent.float().numpy()) - else: - info.latents = latent - - if subset.flip_aug: - img_tensors = torch.flip(img_tensors, dims=[3]) - latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") - for info, latent in zip(batch, latents): - if cache_to_disk: - np.savez(info.latents_npz_flipped, latent.float().numpy()) - else: - info.latents_flipped = latent + if os.path.exists(te_out_npz): + continue + + image_infos_to_cache.append(info) + + if cache_to_disk and not is_main_process: # if cache to disk, don't cache latents in non-main process, set to info only + return + + # prepare tokenizers and text encoders + for text_encoder in text_encoders: + text_encoder.to(device) + if weight_dtype is not None: + text_encoder.to(dtype=weight_dtype) + 
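
When `--cache_text_encoder_outputs_to_disk` is used, the cache file for each image is derived from the image path plus the `_te_outputs.npz` suffix defined near the top of `train_util.py`:

```python
import os
from library.train_util import TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX

image_path = "/path/to/dataset/0001.png"   # hypothetical training image
te_cache_path = os.path.splitext(image_path)[0] + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX
print(te_cache_path)                       # /path/to/dataset/0001_te_outputs.npz
```

Caching is skipped for subsets that shuffle or drop captions (see `is_text_encoder_output_cacheable` above), since cached outputs would no longer match the augmented captions.
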
+ # create batch + batch = [] + batches = [] + for info in image_infos_to_cache: + input_ids1 = self.get_input_ids(info.caption, tokenizers[0]) + input_ids2 = self.get_input_ids(info.caption, tokenizers[1]) + batch.append((info, input_ids1, input_ids2)) + + if len(batch) >= self.batch_size: + batches.append(batch) + batch = [] + + if len(batch) > 0: + batches.append(batch) + + # iterate batches: call text encoder and cache outputs for memory or disk + print("caching text encoder outputs...") + for batch in tqdm(batches): + infos, input_ids1, input_ids2 = zip(*batch) + input_ids1 = torch.stack(input_ids1, dim=0) + input_ids2 = torch.stack(input_ids2, dim=0) + cache_batch_text_encoder_outputs( + infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, input_ids1, input_ids2, weight_dtype + ) def get_image_size(self, image_path): image = Image.open(image_path) return image.size def load_image_with_face_info(self, subset: BaseSubset, image_path: str): - img = self.load_image(image_path) + img = load_image(image_path) face_cx = face_cy = face_w = face_h = 0 if subset.face_crop_aug_range is not None: @@ -894,12 +1033,6 @@ def crop_target(self, subset: BaseSubset, image, face_cx, face_cy, face_w, face_ return image - def load_latents_from_npz(self, image_info: ImageInfo, flipped): - npz_file = image_info.latents_npz_flipped if flipped else image_info.latents_npz - if npz_file is None: - return None - return np.load(npz_file)["arr_0"] - def __len__(self): return self._length @@ -908,24 +1041,47 @@ def __getitem__(self, index): bucket_batch_size = self.buckets_indices[index].bucket_batch_size image_index = self.buckets_indices[index].batch_index * bucket_batch_size + if self.caching_mode is not None: # return batch for latents/text encoder outputs caching + return self.get_item_for_caching(bucket, bucket_batch_size, image_index) + loss_weights = [] captions = [] input_ids_list = [] + input_ids2_list = [] latents_list = [] images = [] + original_sizes_hw = [] + crop_top_lefts = [] + target_sizes_hw = [] + flippeds = [] # 変数名が微妙 + text_encoder_outputs1_list = [] + text_encoder_outputs2_list = [] + text_encoder_pool2_list = [] for image_key in bucket[image_index : image_index + bucket_batch_size]: image_info = self.image_data[image_key] subset = self.image_to_subset[image_key] loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0) + flipped = subset.flip_aug and random.random() < 0.5 # not flipped or flipped with 50% chance + # image/latentsを処理する if image_info.latents is not None: # cache_latents=Trueの場合 - latents = image_info.latents if not subset.flip_aug or random.random() < 0.5 else image_info.latents_flipped + original_size = image_info.latents_original_size + crop_ltrb = image_info.latents_crop_ltrb # calc values later if flipped + if not flipped: + latents = image_info.latents + else: + latents = image_info.latents_flipped + image = None elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合 - latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= 0.5) + latents, original_size, crop_ltrb, flipped_latents = load_latents_from_disk(image_info.latents_npz) + if flipped: + latents = flipped_latents + del flipped_latents latents = torch.FloatTensor(latents) + image = None else: # 画像を読み込み、必要ならcropする @@ -933,7 +1089,9 @@ def __getitem__(self, index): im_h, im_w = img.shape[0:2] if self.enable_bucket: - img = self.trim_and_resize_if_required(subset, img, image_info.bucket_reso, 
image_info.resized_size) + img, original_size, crop_ltrb = trim_and_resize_if_required( + subset.random_crop, img, image_info.bucket_reso, image_info.resized_size + ) else: if face_cx > 0: # 顔位置情報あり img = self.crop_target(subset, img, face_cx, face_cy, face_w, face_h) @@ -953,44 +1111,106 @@ def __getitem__(self, index): im_h == self.height and im_w == self.width ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" + original_size = [im_w, im_h] + crop_ltrb = (0, 0, 0, 0) + # augmentation - aug = self.aug_helper.get_augmentor(subset.color_aug, subset.flip_aug) + aug = self.aug_helper.get_augmentor(subset.color_aug) if aug is not None: img = aug(image=img)["image"] + if flipped: + img = img[:, ::-1, :].copy() # copy to avoid negative stride problem + latents = None image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる images.append(image) latents_list.append(latents) - caption = self.process_caption(subset, image_info.caption) - if self.XTI_layers: - caption_layer = [] - for layer in self.XTI_layers: - token_strings_from = " ".join(self.token_strings) - token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings]) - caption_ = caption.replace(token_strings_from, token_strings_to) - caption_layer.append(caption_) - captions.append(caption_layer) + target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8) + + if not flipped: + crop_left_top = (crop_ltrb[0], crop_ltrb[1]) else: + # crop_ltrb[2] is right, so target_size[0] - crop_ltrb[2] is left in flipped image + crop_left_top = (target_size[0] - crop_ltrb[2], crop_ltrb[1]) + + original_sizes_hw.append((int(original_size[1]), int(original_size[0]))) + crop_top_lefts.append((int(crop_left_top[1]), int(crop_left_top[0]))) + target_sizes_hw.append((int(target_size[1]), int(target_size[0]))) + flippeds.append(flipped) + + # captionとtext encoder outputを処理する + caption = image_info.caption # default + if image_info.text_encoder_outputs1 is not None: + text_encoder_outputs1_list.append(image_info.text_encoder_outputs1) + text_encoder_outputs2_list.append(image_info.text_encoder_outputs2) + text_encoder_pool2_list.append(image_info.text_encoder_pool2) captions.append(caption) - if not self.token_padding_disabled: # this option might be omitted in future + elif image_info.text_encoder_outputs_npz is not None: + text_encoder_outputs1, text_encoder_outputs2, text_encoder_pool2 = load_text_encoder_outputs_from_disk( + image_info.text_encoder_outputs_npz + ) + text_encoder_outputs1_list.append(text_encoder_outputs1) + text_encoder_outputs2_list.append(text_encoder_outputs2) + text_encoder_pool2_list.append(text_encoder_pool2) + captions.append(caption) + else: + caption = self.process_caption(subset, image_info.caption) if self.XTI_layers: - token_caption = self.get_input_ids(caption_layer) + caption_layer = [] + for layer in self.XTI_layers: + token_strings_from = " ".join(self.token_strings) + token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings]) + caption_ = caption.replace(token_strings_from, token_strings_to) + caption_layer.append(caption_) + captions.append(caption_layer) else: - token_caption = self.get_input_ids(caption) - input_ids_list.append(token_caption) + captions.append(caption) + + if not self.token_padding_disabled: # this option might be omitted in future + if self.XTI_layers: + token_caption = self.get_input_ids(caption_layer, self.tokenizers[0]) + else: + token_caption = self.get_input_ids(caption, self.tokenizers[0]) + 
input_ids_list.append(token_caption) + + if len(self.tokenizers) > 1: + if self.XTI_layers: + token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1]) + else: + token_caption2 = self.get_input_ids(caption, self.tokenizers[1]) + input_ids2_list.append(token_caption2) example = {} example["loss_weights"] = torch.FloatTensor(loss_weights) - if self.token_padding_disabled: - # padding=True means pad in the batch - example["input_ids"] = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids + if len(text_encoder_outputs1_list) == 0: + if self.token_padding_disabled: + # padding=True means pad in the batch + example["input_ids"] = self.tokenizer[0](captions, padding=True, truncation=True, return_tensors="pt").input_ids + if len(self.tokenizers) > 1: + example["input_ids2"] = self.tokenizer[1]( + captions, padding=True, truncation=True, return_tensors="pt" + ).input_ids + else: + example["input_ids2"] = None + else: + example["input_ids"] = torch.stack(input_ids_list) + example["input_ids2"] = torch.stack(input_ids2_list) if len(self.tokenizers) > 1 else None + example["text_encoder_outputs1_list"] = None + example["text_encoder_outputs2_list"] = None + example["text_encoder_pool2_list"] = None else: - # batch processing seems to be good - example["input_ids"] = torch.stack(input_ids_list) + example["input_ids"] = None + example["input_ids2"] = None + # # for assertion + # example["input_ids"] = torch.stack([self.get_input_ids(cap, self.tokenizers[0]) for cap in captions]) + # example["input_ids2"] = torch.stack([self.get_input_ids(cap, self.tokenizers[1]) for cap in captions]) + example["text_encoder_outputs1_list"] = torch.stack(text_encoder_outputs1_list) + example["text_encoder_outputs2_list"] = torch.stack(text_encoder_outputs2_list) + example["text_encoder_pool2_list"] = torch.stack(text_encoder_pool2_list) if images[0] is not None: images = torch.stack(images) @@ -1002,10 +1222,76 @@ def __getitem__(self, index): example["latents"] = torch.stack(latents_list) if latents_list[0] is not None else None example["captions"] = captions + example["original_sizes_hw"] = torch.stack([torch.LongTensor(x) for x in original_sizes_hw]) + example["crop_top_lefts"] = torch.stack([torch.LongTensor(x) for x in crop_top_lefts]) + example["target_sizes_hw"] = torch.stack([torch.LongTensor(x) for x in target_sizes_hw]) + example["flippeds"] = flippeds + if self.debug_dataset: example["image_keys"] = bucket[image_index : image_index + self.batch_size] return example + def get_item_for_caching(self, bucket, bucket_batch_size, image_index): + captions = [] + images = [] + input_ids1_list = [] + input_ids2_list = [] + absolute_paths = [] + resized_sizes = [] + bucket_reso = None + flip_aug = None + random_crop = None + + for image_key in bucket[image_index : image_index + bucket_batch_size]: + image_info = self.image_data[image_key] + subset = self.image_to_subset[image_key] + + if flip_aug is None: + flip_aug = subset.flip_aug + random_crop = subset.random_crop + bucket_reso = image_info.bucket_reso + else: + assert flip_aug == subset.flip_aug, "flip_aug must be same in a batch" + assert random_crop == subset.random_crop, "random_crop must be same in a batch" + assert bucket_reso == image_info.bucket_reso, "bucket_reso must be same in a batch" + + caption = image_info.caption # TODO cache some patterns of dropping, shuffling, etc. 
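
Back in `__getitem__`, the crop conditioning is adjusted when a sample is horizontally flipped: `crop_ltrb` keeps the right edge precisely so the new left offset can be computed as `target_width - crop_right`. A small worked example of that adjustment, with a hypothetical off-centre window such as `random_crop` might produce:

```python
target_w = 1024                                     # bucket / target width in pixels
crop_l, crop_t, crop_r, crop_b = 0, 0, 800, 1024    # 800px-wide window at the left edge

flipped = True
if not flipped:
    crop_left_top = (crop_l, crop_t)
else:
    # after a horizontal flip, the window that hugged the left edge
    # now starts target_w - crop_r pixels from the left
    crop_left_top = (target_w - crop_r, crop_t)

print(crop_left_top)                                # (224, 0)
```

The example dict then stores these values as `(top, left)` pairs in `crop_top_lefts`, alongside `original_sizes_hw` and `target_sizes_hw`.
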
+ + if self.caching_mode == "latents": + image = load_image(image_info.absolute_path) + else: + image = None + + if self.caching_mode == "text": + input_ids1 = self.get_input_ids(caption, self.tokenizers[0]) + input_ids2 = self.get_input_ids(caption, self.tokenizers[1]) + else: + input_ids1 = None + input_ids2 = None + + captions.append(caption) + images.append(image) + input_ids1_list.append(input_ids1) + input_ids2_list.append(input_ids2) + absolute_paths.append(image_info.absolute_path) + resized_sizes.append(image_info.resized_size) + + example = {} + + if images[0] is None: + images = None + example["images"] = images + + example["captions"] = captions + example["input_ids1_list"] = input_ids1_list + example["input_ids2_list"] = input_ids2_list + example["absolute_paths"] = absolute_paths + example["resized_sizes"] = resized_sizes + example["flip_aug"] = flip_aug + example["random_crop"] = random_crop + example["bucket_reso"] = bucket_reso + return example + class DreamBoothDataset(BaseDataset): def __init__( @@ -1385,6 +1671,168 @@ def image_key_to_npz_file(self, subset: FineTuningSubset, image_key): return npz_file_norm, npz_file_flip +class ControlNetDataset(BaseDataset): + def __init__( + self, + subsets: Sequence[ControlNetSubset], + batch_size: int, + tokenizer, + max_token_length, + resolution, + enable_bucket: bool, + min_bucket_reso: int, + max_bucket_reso: int, + bucket_reso_steps: int, + bucket_no_upscale: bool, + debug_dataset, + ) -> None: + super().__init__(tokenizer, max_token_length, resolution, debug_dataset) + + db_subsets = [] + for subset in subsets: + db_subset = DreamBoothSubset( + subset.image_dir, + False, + None, + subset.caption_extension, + subset.num_repeats, + subset.shuffle_caption, + subset.keep_tokens, + subset.color_aug, + subset.flip_aug, + subset.face_crop_aug_range, + subset.random_crop, + subset.caption_dropout_rate, + subset.caption_dropout_every_n_epochs, + subset.caption_tag_dropout_rate, + subset.caption_prefix, + subset.caption_suffix, + subset.token_warmup_min, + subset.token_warmup_step, + ) + db_subsets.append(db_subset) + + self.dreambooth_dataset_delegate = DreamBoothDataset( + db_subsets, + batch_size, + tokenizer, + max_token_length, + resolution, + enable_bucket, + min_bucket_reso, + max_bucket_reso, + bucket_reso_steps, + bucket_no_upscale, + 1.0, + debug_dataset, + ) + + # config_util等から参照される値をいれておく(若干微妙なのでなんとかしたい) + self.image_data = self.dreambooth_dataset_delegate.image_data + self.batch_size = batch_size + self.num_train_images = self.dreambooth_dataset_delegate.num_train_images + self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images + + # assert all conditioning data exists + missing_imgs = [] + cond_imgs_with_img = set() + for image_key, info in self.dreambooth_dataset_delegate.image_data.items(): + db_subset = self.dreambooth_dataset_delegate.image_to_subset[image_key] + subset = None + for s in subsets: + if s.image_dir == db_subset.image_dir: + subset = s + break + assert subset is not None, "internal error: subset not found" + + if not os.path.isdir(subset.conditioning_data_dir): + print(f"not directory: {subset.conditioning_data_dir}") + continue + + img_basename = os.path.basename(info.absolute_path) + ctrl_img_path = os.path.join(subset.conditioning_data_dir, img_basename) + if not os.path.exists(ctrl_img_path): + missing_imgs.append(img_basename) + + info.cond_img_path = ctrl_img_path + cond_imgs_with_img.add(ctrl_img_path) + + extra_imgs = [] + for subset in subsets: + conditioning_img_paths = 
glob_images(subset.conditioning_data_dir, "*") + extra_imgs.extend( + [cond_img_path for cond_img_path in conditioning_img_paths if cond_img_path not in cond_imgs_with_img] + ) + + assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}" + assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}" + + self.conditioning_image_transforms = IMAGE_TRANSFORMS + + def make_buckets(self): + self.dreambooth_dataset_delegate.make_buckets() + self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager + self.buckets_indices = self.dreambooth_dataset_delegate.buckets_indices + + def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): + return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process) + + def __len__(self): + return self.dreambooth_dataset_delegate.__len__() + + def __getitem__(self, index): + example = self.dreambooth_dataset_delegate[index] + + bucket = self.dreambooth_dataset_delegate.bucket_manager.buckets[ + self.dreambooth_dataset_delegate.buckets_indices[index].bucket_index + ] + bucket_batch_size = self.dreambooth_dataset_delegate.buckets_indices[index].bucket_batch_size + image_index = self.dreambooth_dataset_delegate.buckets_indices[index].batch_index * bucket_batch_size + + conditioning_images = [] + + for i, image_key in enumerate(bucket[image_index : image_index + bucket_batch_size]): + image_info = self.dreambooth_dataset_delegate.image_data[image_key] + + target_size_hw = example["target_sizes_hw"][i] + original_size_hw = example["original_sizes_hw"][i] + crop_top_left = example["crop_top_lefts"][i] + flipped = example["flippeds"][i] + cond_img = load_image(image_info.cond_img_path) + + if self.dreambooth_dataset_delegate.enable_bucket: + assert ( + cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1] + ), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}" + cond_img = cv2.resize(cond_img, image_info.resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ + + # TODO support random crop + # 現在サポートしているcropはrandomではなく中央のみ + h, w = target_size_hw + ct = (cond_img.shape[0] - h) // 2 + cl = (cond_img.shape[1] - w) // 2 + cond_img = cond_img[ct : ct + h, cl : cl + w] + else: + # assert ( + # cond_img.shape[0] == self.height and cond_img.shape[1] == self.width + # ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" + # resize to target + if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]: + cond_img = cv2.resize( + cond_img, (int(target_size_hw[1]), int(target_size_hw[0])), interpolation=cv2.INTER_LANCZOS4 + ) + + if flipped: + cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride + + cond_img = self.conditioning_image_transforms(cond_img) + conditioning_images.append(cond_img) + + example["conditioning_images"] = torch.stack(conditioning_images).to(memory_format=torch.contiguous_format).float() + + return example + + # behave as Dataset mock class DatasetGroup(torch.utils.data.ConcatDataset): def __init__(self, datasets: Sequence[Union[DreamBoothDataset, FineTuningDataset]]): @@ -1421,9 +1869,27 @@ def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_proc print(f"[Dataset {i}]") dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process) + def cache_text_encoder_outputs( + self, tokenizers, text_encoders, device, 
weight_dtype, cache_to_disk=False, is_main_process=True + ): + for i, dataset in enumerate(self.datasets): + print(f"[Dataset {i}]") + dataset.cache_text_encoder_outputs(tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process) + + def set_caching_mode(self, caching_mode): + for dataset in self.datasets: + dataset.set_caching_mode(caching_mode) + + def verify_bucket_reso_steps(self, min_steps: int): + for dataset in self.datasets: + dataset.verify_bucket_reso_steps(min_steps) + def is_latent_cacheable(self) -> bool: return all([dataset.is_latent_cacheable() for dataset in self.datasets]) + def is_text_encoder_output_cacheable(self) -> bool: + return all([dataset.is_text_encoder_output_cacheable() for dataset in self.datasets]) + def set_current_epoch(self, epoch): for dataset in self.datasets: dataset.set_current_epoch(epoch) @@ -1441,6 +1907,55 @@ def disable_token_padding(self): dataset.disable_token_padding() +def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool): + expected_latents_size = (reso[1] // 8, reso[0] // 8) # bucket_resoはWxHなので注意 + + if not os.path.exists(npz_path): + return False + + npz = np.load(npz_path) + if "latents" not in npz or "original_size" not in npz or "crop_ltrb" not in npz: # old ver? + return False + if npz["latents"].shape[1:3] != expected_latents_size: + return False + + if flip_aug: + if "latents_flipped" not in npz: + return False + if npz["latents_flipped"].shape[1:3] != expected_latents_size: + return False + + return True + + +# 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top) +def load_latents_from_disk( + npz_path, +) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[torch.Tensor]]: + npz = np.load(npz_path) + if "latents" not in npz: + raise ValueError(f"error: npz is old format. please re-generate {npz_path}") + + latents = npz["latents"] + original_size = npz["original_size"].tolist() + crop_ltrb = npz["crop_ltrb"].tolist() + flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None + return latents, original_size, crop_ltrb, flipped_latents + + +def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None): + kwargs = {} + if flipped_latents_tensor is not None: + kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy() + np.savez( + npz_path, + latents=latents_tensor.float().cpu().numpy(), + original_size=np.array(original_size), + crop_ltrb=np.array(crop_ltrb), + **kwargs, + ) + + def debug_dataset(train_dataset, show_input_ids=False): print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") print("`S` for next step, `E` for next epoch no. , Escape for exit. 
/ Sキーで次のステップ、Eキーで次のエポック、Escキーで中断、終了します") @@ -1462,18 +1977,42 @@ def debug_dataset(train_dataset, show_input_ids=False): example = train_dataset[idx] if example["latents"] is not None: print(f"sample has latents from npz file: {example['latents'].size()}") - for j, (ik, cap, lw, iid) in enumerate( - zip(example["image_keys"], example["captions"], example["loss_weights"], example["input_ids"]) + for j, (ik, cap, lw, iid, orgsz, crptl, trgsz, flpdz) in enumerate( + zip( + example["image_keys"], + example["captions"], + example["loss_weights"], + example["input_ids"], + example["original_sizes_hw"], + example["crop_top_lefts"], + example["target_sizes_hw"], + example["flippeds"], + ) ): - print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}"') + print( + f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}", original size: {orgsz}, crop top left: {crptl}, target size: {trgsz}, flipped: {flpdz}' + ) + if show_input_ids: print(f"input ids: {iid}") + if "input_ids2" in example: + print(f"input ids2: {example['input_ids2'][j]}") if example["images"] is not None: im = example["images"][j] print(f"image size: {im.size()}") im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8) im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c im = im[:, :, ::-1] # RGB -> BGR (OpenCV) + + if "conditioning_images" in example: + cond_img = example["conditioning_images"][j] + print(f"conditioning image size: {cond_img.size()}") + cond_img = ((cond_img.numpy() + 1.0) * 127.5).astype(np.uint8) + cond_img = np.transpose(cond_img, (1, 2, 0)) + cond_img = cond_img[:, :, ::-1] + if os.name == "nt": + cv2.imshow("cond_img", cond_img) + if os.name == "nt": # only windows cv2.imshow("img", im) k = cv2.waitKey() @@ -1534,6 +2073,9 @@ def __init__(self, tokenizer, max_token_length, resolution, debug_dataset=False) self.is_reg = False self.image_dir = "dummy" # for metadata + def verify_bucket_reso_steps(self, min_steps: int): + pass + def is_latent_cacheable(self) -> bool: return False @@ -1588,6 +2130,149 @@ def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset: return train_dataset_group +def load_image(image_path): + image = Image.open(image_path) + if not image.mode == "RGB": + image = image.convert("RGB") + img = np.array(image, np.uint8) + return img + + +# 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top, crop right, crop bottom) +def trim_and_resize_if_required( + random_crop: bool, image: Image.Image, reso, resized_size: Tuple[int, int] +) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int, int, int]]: + image_height, image_width = image.shape[0:2] + original_size = (image_width, image_height) # size before resize + + if image_width != resized_size[0] or image_height != resized_size[1]: + # リサイズする + image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ + + image_height, image_width = image.shape[0:2] + + if image_width > reso[0]: + trim_size = image_width - reso[0] + p = trim_size // 2 if not random_crop else random.randint(0, trim_size) + # print("w", trim_size, p) + image = image[:, p : p + reso[0]] + if image_height > reso[1]: + trim_size = image_height - reso[1] + p = trim_size // 2 if not random_crop else random.randint(0, trim_size) + # print("h", trim_size, p) + image = image[p : p + reso[1]] + + # random cropの場合のcropされた値をどうcrop left/topに反映するべきか全くアイデアがない + # I have no idea how to reflect the cropped value in crop left/top in the case of random crop + 
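# Worked example of the center-crop branch above (pure arithmetic, for illustration):
# with an image resized to 1216x832 (WxH) and a bucket reso of (1152, 832),
#     trim_size = 1216 - 1152 = 64
#     p         = 64 // 2     = 32           # or random.randint(0, 64) when random_crop is set
#     image     = image[:, 32 : 32 + 1152]   # final 1152x832 crop
# crop_ltrb below comes from BucketManager.get_crop_ltrb(reso, original_size), which is not
# shown in this hunk; per the flip handling earlier in this patch, crop_ltrb[0]/[1] are the
# crop's left/top and crop_ltrb[2] is its right edge, so a horizontal flip uses
# target width - crop_ltrb[2] as the new left.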
+ crop_ltrb = BucketManager.get_crop_ltrb(reso, original_size) + + assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}" + return image, original_size, crop_ltrb + + +def cache_batch_latents( + vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, random_crop: bool +) -> None: + r""" + requires image_infos to have: absolute_path, bucket_reso, resized_size, latents_npz + optionally requires image_infos to have: image + if cache_to_disk is True, set info.latents_npz + flipped latents is also saved if flip_aug is True + if cache_to_disk is False, set info.latents + latents_flipped is also set if flip_aug is True + latents_original_size and latents_crop_ltrb are also set + """ + images = [] + for info in image_infos: + image = load_image(info.absolute_path) if info.image is None else np.array(info.image, np.uint8) + # TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要 + image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size) + image = IMAGE_TRANSFORMS(image) + images.append(image) + + info.latents_original_size = original_size + info.latents_crop_ltrb = crop_ltrb + + img_tensors = torch.stack(images, dim=0) + img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype) + + with torch.no_grad(): + latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") + + if flip_aug: + img_tensors = torch.flip(img_tensors, dims=[3]) + with torch.no_grad(): + flipped_latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") + else: + flipped_latents = [None] * len(latents) + + for info, latent, flipped_latent in zip(image_infos, latents, flipped_latents): + # check NaN + if torch.isnan(latents).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()): + raise RuntimeError(f"NaN detected in latents: {info.absolute_path}") + + if cache_to_disk: + save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent) + else: + info.latents = latent + if flip_aug: + info.latents_flipped = flipped_latent + + # FIXME this slows down caching a lot, specify this as an option + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +def cache_batch_text_encoder_outputs( + image_infos, tokenizers, text_encoders, max_token_length, cache_to_disk, input_ids1, input_ids2, dtype +): + input_ids1 = input_ids1.to(text_encoders[0].device) + input_ids2 = input_ids2.to(text_encoders[1].device) + + with torch.no_grad(): + b_hidden_state1, b_hidden_state2, b_pool2 = get_hidden_states_sdxl( + max_token_length, + input_ids1, + input_ids2, + tokenizers[0], + tokenizers[1], + text_encoders[0], + text_encoders[1], + dtype, + ) + + # ここでcpuに移動しておかないと、上書きされてしまう + b_hidden_state1 = b_hidden_state1.detach().to("cpu") # b,n*75+2,768 + b_hidden_state2 = b_hidden_state2.detach().to("cpu") # b,n*75+2,1280 + b_pool2 = b_pool2.detach().to("cpu") # b,1280 + + for info, hidden_state1, hidden_state2, pool2 in zip(image_infos, b_hidden_state1, b_hidden_state2, b_pool2): + if cache_to_disk: + save_text_encoder_outputs_to_disk(info.text_encoder_outputs_npz, hidden_state1, hidden_state2, pool2) + else: + info.text_encoder_outputs1 = hidden_state1 + info.text_encoder_outputs2 = hidden_state2 + info.text_encoder_pool2 = pool2 + + +def save_text_encoder_outputs_to_disk(npz_path, hidden_state1, hidden_state2, pool2): + np.savez( + npz_path, + 
hidden_state1=hidden_state1.cpu().float().numpy(), + hidden_state2=hidden_state2.cpu().float().numpy(), + pool2=pool2.cpu().float().numpy(), + ) + + +def load_text_encoder_outputs_from_disk(npz_path): + with np.load(npz_path) as f: + hidden_state1 = torch.from_numpy(f["hidden_state1"]) + hidden_state2 = torch.from_numpy(f["hidden_state2"]) if "hidden_state2" in f else None + pool2 = torch.from_numpy(f["pool2"]) if "pool2" in f else None + return hidden_state1, hidden_state2, pool2 + + # endregion # region モジュール入れ替え部 @@ -1676,276 +2361,98 @@ def addnet_hash_legacy(b): return m.hexdigest()[0:8] -def addnet_hash_safetensors(b): - """New model hash used by sd-webui-additional-networks for .safetensors format files""" - hash_sha256 = hashlib.sha256() - blksize = 1024 * 1024 - - b.seek(0) - header = b.read(8) - n = int.from_bytes(header, "little") - - offset = n + 8 - b.seek(offset) - for chunk in iter(lambda: b.read(blksize), b""): - hash_sha256.update(chunk) - - return hash_sha256.hexdigest() - - -def get_git_revision_hash() -> str: - try: - return subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=os.path.dirname(__file__)).decode("ascii").strip() - except: - return "(unknown)" - - -# flash attention forwards and backwards - -# https://arxiv.org/abs/2205.14135 - - -class FlashAttentionFunction(torch.autograd.function.Function): - @staticmethod - @torch.no_grad() - def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): - """Algorithm 2 in the paper""" - - device = q.device - dtype = q.dtype - max_neg_value = -torch.finfo(q.dtype).max - qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) - - o = torch.zeros_like(q) - all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) - all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device) - - scale = q.shape[-1] ** -0.5 - - if not exists(mask): - mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) - else: - mask = rearrange(mask, "b n -> b 1 1 n") - mask = mask.split(q_bucket_size, dim=-1) - - row_splits = zip( - q.split(q_bucket_size, dim=-2), - o.split(q_bucket_size, dim=-2), - mask, - all_row_sums.split(q_bucket_size, dim=-2), - all_row_maxes.split(q_bucket_size, dim=-2), - ) - - for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): - q_start_index = ind * q_bucket_size - qk_len_diff - - col_splits = zip( - k.split(k_bucket_size, dim=-2), - v.split(k_bucket_size, dim=-2), - ) - - for k_ind, (kc, vc) in enumerate(col_splits): - k_start_index = k_ind * k_bucket_size - - attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale - - if exists(row_mask): - attn_weights.masked_fill_(~row_mask, max_neg_value) - - if causal and q_start_index < (k_start_index + k_bucket_size - 1): - causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( - q_start_index - k_start_index + 1 - ) - attn_weights.masked_fill_(causal_mask, max_neg_value) - - block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) - attn_weights -= block_row_maxes - exp_weights = torch.exp(attn_weights) - - if exists(row_mask): - exp_weights.masked_fill_(~row_mask, 0.0) - - block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON) - - new_row_maxes = torch.maximum(block_row_maxes, row_maxes) - - exp_values = einsum("... i j, ... j d -> ... 
i d", exp_weights, vc) - - exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) - exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) - - new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums - - oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values) - - row_maxes.copy_(new_row_maxes) - row_sums.copy_(new_row_sums) - - ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) - ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) - - return o - - @staticmethod - @torch.no_grad() - def backward(ctx, do): - """Algorithm 4 in the paper""" - - causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args - q, k, v, o, l, m = ctx.saved_tensors - - device = q.device - - max_neg_value = -torch.finfo(q.dtype).max - qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) - - dq = torch.zeros_like(q) - dk = torch.zeros_like(k) - dv = torch.zeros_like(v) - - row_splits = zip( - q.split(q_bucket_size, dim=-2), - o.split(q_bucket_size, dim=-2), - do.split(q_bucket_size, dim=-2), - mask, - l.split(q_bucket_size, dim=-2), - m.split(q_bucket_size, dim=-2), - dq.split(q_bucket_size, dim=-2), - ) - - for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): - q_start_index = ind * q_bucket_size - qk_len_diff - - col_splits = zip( - k.split(k_bucket_size, dim=-2), - v.split(k_bucket_size, dim=-2), - dk.split(k_bucket_size, dim=-2), - dv.split(k_bucket_size, dim=-2), - ) - - for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): - k_start_index = k_ind * k_bucket_size - - attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale - - if causal and q_start_index < (k_start_index + k_bucket_size - 1): - causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( - q_start_index - k_start_index + 1 - ) - attn_weights.masked_fill_(causal_mask, max_neg_value) - - exp_attn_weights = torch.exp(attn_weights - mc) - - if exists(row_mask): - exp_attn_weights.masked_fill_(~row_mask, 0.0) - - p = exp_attn_weights / lc - - dv_chunk = einsum("... i j, ... i d -> ... j d", p, doc) - dp = einsum("... i d, ... j d -> ... i j", doc, vc) - - D = (doc * oc).sum(dim=-1, keepdims=True) - ds = p * scale * (dp - D) - - dq_chunk = einsum("... i j, ... j d -> ... i d", ds, kc) - dk_chunk = einsum("... i j, ... i d -> ... 
j d", ds, qc) - - dqc.add_(dq_chunk) - dkc.add_(dk_chunk) - dvc.add_(dv_chunk) - - return dq, dk, dv, None, None, None, None - - -def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers): - # unet is not used currently, but it is here for future use - if mem_eff_attn: - replace_unet_cross_attn_to_memory_efficient() - elif xformers: - replace_unet_cross_attn_to_xformers() - - -def replace_unet_cross_attn_to_memory_efficient(): - print("CrossAttention.forward has been replaced to FlashAttention (not xformers)") - flash_func = FlashAttentionFunction - - def forward_flash_attn(self, x, context=None, mask=None): - q_bucket_size = 512 - k_bucket_size = 1024 +def addnet_hash_safetensors(b): + """New model hash used by sd-webui-additional-networks for .safetensors format files""" + hash_sha256 = hashlib.sha256() + blksize = 1024 * 1024 - h = self.heads - q = self.to_q(x) + b.seek(0) + header = b.read(8) + n = int.from_bytes(header, "little") - context = context if context is not None else x - context = context.to(x.dtype) + offset = n + 8 + b.seek(offset) + for chunk in iter(lambda: b.read(blksize), b""): + hash_sha256.update(chunk) - if hasattr(self, "hypernetwork") and self.hypernetwork is not None: - context_k, context_v = self.hypernetwork.forward(x, context) - context_k = context_k.to(x.dtype) - context_v = context_v.to(x.dtype) - else: - context_k = context - context_v = context + return hash_sha256.hexdigest() - k = self.to_k(context_k) - v = self.to_v(context_v) - del context, x - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) +def get_git_revision_hash() -> str: + try: + return subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=os.path.dirname(__file__)).decode("ascii").strip() + except: + return "(unknown)" - out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size) - out = rearrange(out, "b h n d -> b n (h d)") +# def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers): +# replace_attentions_for_hypernetwork() +# # unet is not used currently, but it is here for future use +# unet.enable_xformers_memory_efficient_attention() +# return +# if mem_eff_attn: +# unet.set_attn_processor(FlashAttnProcessor()) +# elif xformers: +# unet.enable_xformers_memory_efficient_attention() - # diffusers 0.7.0~ わざわざ変えるなよ (;´Д`) - out = self.to_out[0](out) - out = self.to_out[1](out) - return out - diffusers.models.attention.CrossAttention.forward = forward_flash_attn +# def replace_unet_cross_attn_to_xformers(): +# print("CrossAttention.forward has been replaced to enable xformers.") +# try: +# import xformers.ops +# except ImportError: +# raise ImportError("No xformers / xformersがインストールされていないようです") +# def forward_xformers(self, x, context=None, mask=None): +# h = self.heads +# q_in = self.to_q(x) -def replace_unet_cross_attn_to_xformers(): - print("CrossAttention.forward has been replaced to enable xformers.") - try: - import xformers.ops - except ImportError: - raise ImportError("No xformers / xformersがインストールされていないようです") +# context = default(context, x) +# context = context.to(x.dtype) - def forward_xformers(self, x, context=None, mask=None): - h = self.heads - q_in = self.to_q(x) +# if hasattr(self, "hypernetwork") and self.hypernetwork is not None: +# context_k, context_v = self.hypernetwork.forward(x, context) +# context_k = context_k.to(x.dtype) +# context_v = context_v.to(x.dtype) +# else: +# context_k = context +# context_v = context - 
context = default(context, x) - context = context.to(x.dtype) +# k_in = self.to_k(context_k) +# v_in = self.to_v(context_v) - if hasattr(self, "hypernetwork") and self.hypernetwork is not None: - context_k, context_v = self.hypernetwork.forward(x, context) - context_k = context_k.to(x.dtype) - context_v = context_v.to(x.dtype) - else: - context_k = context - context_v = context +# q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in)) +# del q_in, k_in, v_in - k_in = self.to_k(context_k) - v_in = self.to_v(context_v) +# q = q.contiguous() +# k = k.contiguous() +# v = v.contiguous() +# out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in)) - del q_in, k_in, v_in +# out = rearrange(out, "b n h d -> b n (h d)", h=h) - q = q.contiguous() - k = k.contiguous() - v = v.contiguous() - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる +# # diffusers 0.7.0~ +# out = self.to_out[0](out) +# out = self.to_out[1](out) +# return out - out = rearrange(out, "b n h d -> b n (h d)", h=h) - # diffusers 0.7.0~ - out = self.to_out[0](out) - out = self.to_out[1](out) - return out +# diffusers.models.attention.CrossAttention.forward = forward_xformers +def replace_unet_modules(unet: UNet2DConditionModel, mem_eff_attn, xformers, sdpa): + if mem_eff_attn: + print("Enable memory efficient attention for U-Net") + unet.set_use_memory_efficient_attention(False, True) + elif xformers: + print("Enable xformers for U-Net") + try: + import xformers.ops + except ImportError: + raise ImportError("No xformers / xformersがインストールされていないようです") - diffusers.models.attention.CrossAttention.forward = forward_xformers + unet.set_use_memory_efficient_attention(True, False) + elif sdpa: + print("Enable SDPA for U-Net") + unet.set_use_sdpa(True) """ @@ -2008,6 +2515,106 @@ def forward_flash_attn(self, hidden_states): # region arguments +def load_metadata_from_safetensors(safetensors_file: str) -> dict: + """r + This method locks the file. see https://github.com/huggingface/safetensors/issues/164 + If the file isn't .safetensors or doesn't have metadata, return empty dict. 
+ """ + if os.path.splitext(safetensors_file)[1] != ".safetensors": + return {} + + with safetensors.safe_open(safetensors_file, framework="pt", device="cpu") as f: + metadata = f.metadata() + if metadata is None: + metadata = {} + return metadata + + +# this metadata is referred from train_network and various scripts, so we wrote here +SS_METADATA_KEY_V2 = "ss_v2" +SS_METADATA_KEY_BASE_MODEL_VERSION = "ss_base_model_version" +SS_METADATA_KEY_NETWORK_MODULE = "ss_network_module" +SS_METADATA_KEY_NETWORK_DIM = "ss_network_dim" +SS_METADATA_KEY_NETWORK_ALPHA = "ss_network_alpha" +SS_METADATA_KEY_NETWORK_ARGS = "ss_network_args" + +SS_METADATA_MINIMUM_KEYS = [ + SS_METADATA_KEY_V2, + SS_METADATA_KEY_BASE_MODEL_VERSION, + SS_METADATA_KEY_NETWORK_MODULE, + SS_METADATA_KEY_NETWORK_DIM, + SS_METADATA_KEY_NETWORK_ALPHA, + SS_METADATA_KEY_NETWORK_ARGS, +] + + +def build_minimum_network_metadata( + v2: Optional[bool], + base_model: Optional[str], + network_module: str, + network_dim: str, + network_alpha: str, + network_args: Optional[dict], +): + # old LoRA doesn't have base_model + metadata = { + SS_METADATA_KEY_NETWORK_MODULE: network_module, + SS_METADATA_KEY_NETWORK_DIM: network_dim, + SS_METADATA_KEY_NETWORK_ALPHA: network_alpha, + } + if v2 is not None: + metadata[SS_METADATA_KEY_V2] = v2 + if base_model is not None: + metadata[SS_METADATA_KEY_BASE_MODEL_VERSION] = base_model + if network_args is not None: + metadata[SS_METADATA_KEY_NETWORK_ARGS] = json.dumps(network_args) + return metadata + + +def get_sai_model_spec( + state_dict: dict, + args: argparse.Namespace, + sdxl: bool, + lora: bool, + textual_inversion: bool, + is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA +): + timestamp = time.time() + + v2 = args.v2 + v_parameterization = args.v_parameterization + reso = args.resolution + + title = args.metadata_title if args.metadata_title is not None else args.output_name + + if args.min_timestep is not None or args.max_timestep is not None: + min_time_step = args.min_timestep if args.min_timestep is not None else 0 + max_time_step = args.max_timestep if args.max_timestep is not None else 1000 + timesteps = (min_time_step, max_time_step) + else: + timesteps = None + + metadata = sai_model_spec.build_metadata( + state_dict, + v2, + v_parameterization, + sdxl, + lora, + textual_inversion, + timestamp, + title=title, + reso=reso, + is_stable_diffusion_ckpt=is_stable_diffusion_ckpt, + author=args.metadata_author, + description=args.metadata_description, + license=args.metadata_license, + tags=args.metadata_tags, + timesteps=timesteps, + clip_skip=args.clip_skip, # None or int + ) + return metadata + + def add_sd_models_arguments(parser: argparse.ArgumentParser): # for pretrained models parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む") @@ -2195,6 +2802,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う", ) parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う") + parser.add_argument( + "--sdpa", + action="store_true", + help="use sdpa for CrossAttention (requires PyTorch 2.0) / CrossAttentionにsdpaを使う(PyTorch 2.0が必要)", + ) parser.add_argument( "--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ" ) @@ -2231,6 +2843,9 @@ def 
add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: "--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度" ) parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する") + parser.add_argument( + "--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する" + ) # TODO move to SDXL training, because it is not supported by SD1/2 parser.add_argument( "--clip_skip", type=int, @@ -2257,6 +2872,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: default=None, help="name of tracker to use for logging, default is script-specific default name / ログ出力に使用するtrackerの名前、省略時はスクリプトごとのデフォルト名", ) + parser.add_argument( + "--log_tracker_config", + type=str, + default=None, + help="path to tracker config file to use for logging / ログ出力に使用するtrackerの設定ファイルのパス", + ) parser.add_argument( "--wandb_api_key", type=str, @@ -2275,6 +2896,13 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: default=None, help="enable multires noise with this number of iterations (if enabled, around 6-10 is recommended) / Multires noiseを有効にしてこのイテレーション数を設定する(有効にする場合は6-10程度を推奨)", ) + parser.add_argument( + "--ip_noise_gamma", + type=float, + default=None, + help="enable input perturbation noise. used for regularization. recommended value: around 0.1 (from arxiv.org/abs/2301.11706) " + + "/ input perturbation noiseを有効にする。正則化に使用される。推奨値: 0.1程度 (arxiv.org/abs/2301.11706 より)", + ) # parser.add_argument( # "--perlin_noise", # type=int, @@ -2293,6 +2921,24 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: default=None, help="add `latent mean absolute value * this value` to noise_offset (disabled if None, default) / latentの平均値の絶対値 * この値をnoise_offsetに加算する(Noneの場合は無効、デフォルト)", ) + parser.add_argument( + "--zero_terminal_snr", + action="store_true", + help="fix noise scheduler betas to enforce zero terminal SNR / noise schedulerのbetasを修正して、zero terminal SNRを強制する", + ) + parser.add_argument( + "--min_timestep", + type=int, + default=None, + help="set minimum time step for U-Net training (0~999, default is 0) / U-Net学習時のtime stepの最小値を設定する(0~999で指定、省略時はデフォルト値(0)) ", + ) + parser.add_argument( + "--max_timestep", + type=int, + default=None, + help="set maximum time step for U-Net training (1~1000, default is 1000) / U-Net学習時のtime stepの最大値を設定する(1~1000で指定、省略時はデフォルト値(1000))", + ) + parser.add_argument( "--lowram", action="store_true", @@ -2346,6 +2992,38 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: "--output_config", action="store_true", help="output command line args to given .toml file / 引数を.tomlファイルに出力する" ) + # SAI Model spec + parser.add_argument( + "--metadata_title", + type=str, + default=None, + help="title for model metadata (default is output_name) / メタデータに書き込まれるモデルタイトル、省略時はoutput_name", + ) + parser.add_argument( + "--metadata_author", + type=str, + default=None, + help="author name for model metadata / メタデータに書き込まれるモデル作者名", + ) + parser.add_argument( + "--metadata_description", + type=str, + default=None, + help="description for model metadata / メタデータに書き込まれるモデル説明", + ) + parser.add_argument( + "--metadata_license", + type=str, + default=None, + help="license for model metadata / メタデータに書き込まれるモデルライセンス", + ) + parser.add_argument( + "--metadata_tags", + type=str, + default=None, + help="tags for model metadata, separated by comma / 
メタデータに書き込まれるモデルタグ、カンマ区切り", + ) + if support_dreambooth: # DreamBooth training parser.add_argument( @@ -2355,7 +3033,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: def verify_training_args(args: argparse.Namespace): if args.v_parameterization and not args.v2: - print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") + print("v_parameterization should be with v2 not v1 or sdxl / v1やsdxlでv_parameterizationを使用することは想定されていません") if args.v2 and args.clip_skip is not None: print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") @@ -2366,11 +3044,11 @@ def verify_training_args(args: argparse.Namespace): ) # noise_offset, perlin_noise, multires_noise_iterations cannot be enabled at the same time - # Listを使って数えてもいいけど並べてしまえ - if args.noise_offset is not None and args.multires_noise_iterations is not None: - raise ValueError( - "noise_offset and multires_noise_iterations cannot be enabled at the same time / noise_offsetとmultires_noise_iterationsを同時に有効にできません" - ) + # # Listを使って数えてもいいけど並べてしまえ + # if args.noise_offset is not None and args.multires_noise_iterations is not None: + # raise ValueError( + # "noise_offset and multires_noise_iterations cannot be enabled at the same time / noise_offsetとmultires_noise_iterationsを同時に有効にできません" + # ) # if args.noise_offset is not None and args.perlin_noise is not None: # raise ValueError("noise_offset and perlin_noise cannot be enabled at the same time / noise_offsetとperlin_noiseは同時に有効にできません") # if args.perlin_noise is not None and args.multires_noise_iterations is not None: @@ -2386,6 +3064,17 @@ def verify_training_args(args: argparse.Namespace): "scale_v_pred_loss_like_noise_pred can be enabled only with v_parameterization / scale_v_pred_loss_like_noise_predはv_parameterizationが有効なときのみ有効にできます" ) + if args.v_pred_like_loss and args.v_parameterization: + raise ValueError( + "v_pred_like_loss cannot be enabled with v_parameterization / v_pred_like_lossはv_parameterizationが有効なときには有効にできません" + ) + + if args.zero_terminal_snr and not args.v_parameterization: + print( + f"zero_terminal_snr is enabled, but v_parameterization is not enabled. 
training will be unexpected" + + " / zero_terminal_snrが有効ですが、v_parameterizationが有効ではありません。学習結果は想定外になる可能性があります" + ) + def add_dataset_arguments( parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool @@ -2410,6 +3099,18 @@ def add_dataset_arguments( default=0, help="keep heading N tokens when shuffling caption tokens (token means comma separated strings) / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す(トークンはカンマ区切りの各部分を意味する)", ) + parser.add_argument( + "--caption_prefix", + type=str, + default=None, + help="prefix for caption text / captionのテキストの先頭に付ける文字列", + ) + parser.add_argument( + "--caption_suffix", + type=str, + default=None, + help="suffix for caption text / captionのテキストの末尾に付ける文字列", + ) parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする") parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする") parser.add_argument( @@ -2693,32 +3394,9 @@ def get_optimizer(args, trainable_params): # print("optkwargs:", optimizer_kwargs) lr = args.learning_rate + optimizer = None - if optimizer_type == "AdamW8bit".lower(): - try: - import bitsandbytes as bnb - except ImportError: - raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです") - print(f"use 8-bit AdamW optimizer | {optimizer_kwargs}") - optimizer_class = bnb.optim.AdamW8bit - optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) - - elif optimizer_type == "SGDNesterov8bit".lower(): - try: - import bitsandbytes as bnb - except ImportError: - raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです") - print(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}") - if "momentum" not in optimizer_kwargs: - print( - f"8-bit SGD with Nesterov must be with momentum, set momentum to 0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に設定します" - ) - optimizer_kwargs["momentum"] = 0.9 - - optimizer_class = bnb.optim.SGD8bit - optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs) - - elif optimizer_type == "Lion".lower(): + if optimizer_type == "Lion".lower(): try: import lion_pytorch except ImportError: @@ -2726,37 +3404,53 @@ def get_optimizer(args, trainable_params): print(f"use Lion optimizer | {optimizer_kwargs}") optimizer_class = lion_pytorch.Lion optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) - + elif optimizer_type.endswith("8bit".lower()): try: import bitsandbytes as bnb except ImportError: raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです") - if optimizer_type == "Lion8bit".lower(): - print(f"use 8-bit Lion optimizer | {optimizer_kwargs}") - try: - optimizer_class = bnb.optim.Lion8bit - except AttributeError: - raise AttributeError( - "No Lion8bit. The version of bitsandbytes installed seems to be old. Please install 0.38.0 or later. 
/ Lion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.38.0以上をインストールしてください" - ) + if optimizer_type == "AdamW8bit".lower(): + print(f"use 8-bit AdamW optimizer | {optimizer_kwargs}") + optimizer_class = bnb.optim.AdamW8bit + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + elif optimizer_type == "SGDNesterov8bit".lower(): + print(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}") + if "momentum" not in optimizer_kwargs: + print( + f"8-bit SGD with Nesterov must be with momentum, set momentum to 0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に設定します" + ) + optimizer_kwargs["momentum"] = 0.9 + + optimizer_class = bnb.optim.SGD8bit + optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs) + + elif optimizer_type == "Lion8bit".lower(): + print(f"use 8-bit Lion optimizer | {optimizer_kwargs}") + try: + optimizer_class = bnb.optim.Lion8bit + except AttributeError: + raise AttributeError( + "No Lion8bit. The version of bitsandbytes installed seems to be old. Please install 0.38.0 or later. / Lion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.38.0以上をインストールしてください" + ) elif optimizer_type == "PagedAdamW8bit".lower(): - print(f"use 8-bit PagedAdamW optimizer | {optimizer_kwargs}") - try: - optimizer_class = bnb.optim.PagedAdamW8bit - except AttributeError: - raise AttributeError( - "No PagedAdamW8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedAdamW8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" - ) + print(f"use 8-bit PagedAdamW optimizer | {optimizer_kwargs}") + try: + optimizer_class = bnb.optim.PagedAdamW8bit + except AttributeError: + raise AttributeError( + "No PagedAdamW8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedAdamW8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" + ) elif optimizer_type == "PagedLion8bit".lower(): - print(f"use 8-bit Paged Lion optimizer | {optimizer_kwargs}") - try: - optimizer_class = bnb.optim.PagedLion8bit - except AttributeError: - raise AttributeError( - "No PagedLion8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedLion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" - ) + print(f"use 8-bit Paged Lion optimizer | {optimizer_kwargs}") + try: + optimizer_class = bnb.optim.PagedLion8bit + except AttributeError: + raise AttributeError( + "No PagedLion8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. 
/ PagedLion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" + ) optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) @@ -2888,7 +3582,7 @@ def get_optimizer(args, trainable_params): optimizer_class = torch.optim.AdamW optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) - else: + if optimizer is None: # 任意のoptimizerを使う optimizer_type = args.optimizer_type # lowerでないやつ(微妙) print(f"use {optimizer_type} | {optimizer_kwargs}") @@ -2908,10 +3602,8 @@ def get_optimizer(args, trainable_params): return optimizer_name, optimizer_args, optimizer -# Monkeypatch newer get_scheduler() function overridng current version of diffusers.optimizer.get_scheduler -# code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6 -# Which is a newer release of diffusers than currently packaged with sd-scripts -# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts +# Modified version of get_scheduler() function from diffusers.optimizer.get_scheduler +# Add some checking and features to the original function. def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): @@ -2920,7 +3612,7 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): """ name = args.lr_scheduler num_warmup_steps: Optional[int] = args.lr_warmup_steps - num_training_steps = args.max_train_steps * num_processes * args.gradient_accumulation_steps + num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps num_cycles = args.lr_scheduler_num_cycles power = args.lr_scheduler_power @@ -2928,19 +3620,7 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0: for arg in args.lr_scheduler_args: key, value = arg.split("=") - value = ast.literal_eval(value) - # value = value.split(",") - # for i in range(len(value)): - # if value[i].lower() == "true" or value[i].lower() == "false": - # value[i] = value[i].lower() == "true" - # else: - # value[i] = ast.literal_eval(value[i]) - # if len(value) == 1: - # value = value[0] - # else: - # value = list(value) # some may use list? 
- lr_scheduler_kwargs[key] = value def wrap_check_needless_num_warmup_steps(return_vals): @@ -2972,15 +3652,19 @@ def wrap_check_needless_num_warmup_steps(return_vals): name = SchedulerType(name) schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] + if name == SchedulerType.CONSTANT: - return wrap_check_needless_num_warmup_steps(schedule_func(optimizer)) + return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs)) + + if name == SchedulerType.PIECEWISE_CONSTANT: + return schedule_func(optimizer, **lr_scheduler_kwargs) # step_rules and last_epoch are given as kwargs # All other schedulers require `num_warmup_steps` if num_warmup_steps is None: raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") if name == SchedulerType.CONSTANT_WITH_WARMUP: - return schedule_func(optimizer, num_warmup_steps=num_warmup_steps) + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **lr_scheduler_kwargs) # All other schedulers require `num_training_steps` if num_training_steps is None: @@ -2988,13 +3672,19 @@ def wrap_check_needless_num_warmup_steps(return_vals): if name == SchedulerType.COSINE_WITH_RESTARTS: return schedule_func( - optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + num_cycles=num_cycles, + **lr_scheduler_kwargs, ) if name == SchedulerType.POLYNOMIAL: - return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power) + return schedule_func( + optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power, **lr_scheduler_kwargs + ) - return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps) + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **lr_scheduler_kwargs) def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool): @@ -3086,23 +3776,9 @@ def prepare_accelerator(args: argparse.Namespace): gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, log_with=log_with, - logging_dir=logging_dir, + project_dir=logging_dir, ) - - # accelerateの互換性問題を解決する - accelerator_0_15 = True - try: - accelerator.unwrap_model("dummy", True) - print("Using accelerator 0.15.0 or above.") - except TypeError: - accelerator_0_15 = False - - def unwrap_model(model): - if accelerator_0_15: - return accelerator.unwrap_model(model, True) - return accelerator.unwrap_model(model) - - return accelerator, unwrap_model + return accelerator def prepare_dtype(args: argparse.Namespace): @@ -3123,13 +3799,15 @@ def prepare_dtype(args: argparse.Namespace): return weight_dtype, save_dtype -def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"): +def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", unet_use_linear_projection_in_v2=False): name_or_path = args.pretrained_model_name_or_path name_or_path = os.path.realpath(name_or_path) if os.path.islink(name_or_path) else name_or_path load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers if load_stable_diffusion_format: print(f"load StableDiffusion checkpoint: {name_or_path}") - text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path, device) + text_encoder, vae, unet = 
model_util.load_models_from_stable_diffusion_checkpoint( + args.v2, name_or_path, device, unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2 + ) else: # Diffusers model is loaded to CPU print(f"load Diffusers pretrained models: {name_or_path}") @@ -3139,11 +3817,26 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"): print( f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}" ) + raise ex text_encoder = pipe.text_encoder vae = pipe.vae unet = pipe.unet del pipe + # Diffusers U-Net to original U-Net + # TODO *.ckpt/*.safetensorsのv2と同じ形式にここで変換すると良さそう + # print(f"unet config: {unet.config}") + original_unet = UNet2DConditionModel( + unet.config.sample_size, + unet.config.attention_head_dim, + unet.config.cross_attention_dim, + unet.config.use_linear_projection, + unet.config.upcast_attention, + ) + original_unet.load_state_dict(unet.state_dict()) + unet = original_unet + print("U-Net converted to original U-Net") + # VAEを読み込む if args.vae is not None: vae = model_util.load_vae(args.vae, weight_dtype) @@ -3152,19 +3845,28 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"): return text_encoder, vae, unet, load_stable_diffusion_format +# TODO remove this function in the future def transform_if_model_is_DDP(text_encoder, unet, network=None): # Transform text_encoder, unet and network from DistributedDataParallel return (model.module if type(model) == DDP else model for model in [text_encoder, unet, network] if model is not None) -def load_target_model(args, weight_dtype, accelerator): +def transform_models_if_DDP(models): + # Transform text_encoder, unet and network from DistributedDataParallel + return [model.module if type(model) == DDP else model for model in models if model is not None] + + +def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False): # load models for each process for pi in range(accelerator.state.num_processes): if pi == accelerator.state.local_process_index: print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") text_encoder, vae, unet, load_stable_diffusion_format = _load_target_model( - args, weight_dtype, accelerator.device if args.lowram else "cpu" + args, + weight_dtype, + accelerator.device if args.lowram else "cpu", + unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2, ) # work on low-ram device @@ -3196,6 +3898,7 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod if input_ids.size()[-1] != tokenizer.model_max_length: return text_encoder(input_ids)[0] + # input_ids: b,n,77 b_size = input_ids.size()[0] input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77 @@ -3237,6 +3940,114 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod return encoder_hidden_states +def pool_workaround( + text_encoder: CLIPTextModelWithProjection, last_hidden_state: torch.Tensor, input_ids: torch.Tensor, eos_token_id: int +): + r""" + workaround for CLIP's pooling bug: it returns the hidden states for the max token id as the pooled output + instead of the hidden states for the EOS token + If we use Textual Inversion, we need to use the hidden states for the EOS token as the pooled output + + Original code from CLIP's pooling function: + + \# text_embeds.shape = [batch_size, sequence_length, transformer.width] + \# take 
features from the eot embedding (eot_token is the highest number in each sequence) + \# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 + pooled_output = last_hidden_state[ + torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), + input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1), + ] + """ + + # input_ids: b*n,77 + # find index for EOS token + + # Following code is not working if one of the input_ids has multiple EOS tokens (very odd case) + # eos_token_index = torch.where(input_ids == eos_token_id)[1] + # eos_token_index = eos_token_index.to(device=last_hidden_state.device) + + # Create a mask where the EOS tokens are + eos_token_mask = (input_ids == eos_token_id).int() + + # Use argmax to find the last index of the EOS token for each element in the batch + eos_token_index = torch.argmax(eos_token_mask, dim=1) # this will be 0 if there is no EOS token, it's fine + eos_token_index = eos_token_index.to(device=last_hidden_state.device) + + # get hidden states for EOS token + pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), eos_token_index] + + # apply projection: projection may be of different dtype than last_hidden_state + pooled_output = text_encoder.text_projection(pooled_output.to(text_encoder.text_projection.weight.dtype)) + pooled_output = pooled_output.to(last_hidden_state.dtype) + + return pooled_output + + +def get_hidden_states_sdxl( + max_token_length: int, + input_ids1: torch.Tensor, + input_ids2: torch.Tensor, + tokenizer1: CLIPTokenizer, + tokenizer2: CLIPTokenizer, + text_encoder1: CLIPTextModel, + text_encoder2: CLIPTextModelWithProjection, + weight_dtype: Optional[str] = None, +): + # input_ids: b,n,77 -> b*n, 77 + b_size = input_ids1.size()[0] + input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length)) # batch_size*n, 77 + input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length)) # batch_size*n, 77 + + # text_encoder1 + enc_out = text_encoder1(input_ids1, output_hidden_states=True, return_dict=True) + hidden_states1 = enc_out["hidden_states"][11] + + # text_encoder2 + enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True) + hidden_states2 = enc_out["hidden_states"][-2] # penuultimate layer + + # pool2 = enc_out["text_embeds"] + pool2 = pool_workaround(text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id) + + # b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280 + n_size = 1 if max_token_length is None else max_token_length // 75 + hidden_states1 = hidden_states1.reshape((b_size, -1, hidden_states1.shape[-1])) + hidden_states2 = hidden_states2.reshape((b_size, -1, hidden_states2.shape[-1])) + + if max_token_length is not None: + # bs*3, 77, 768 or 1024 + # encoder1: ... の三連を ... へ戻す + states_list = [hidden_states1[:, 0].unsqueeze(1)] # + for i in range(1, max_token_length, tokenizer1.model_max_length): + states_list.append(hidden_states1[:, i : i + tokenizer1.model_max_length - 2]) # の後から の前まで + states_list.append(hidden_states1[:, -1].unsqueeze(1)) # + hidden_states1 = torch.cat(states_list, dim=1) + + # v2: ... ... の三連を ... ... 
へ戻す 正直この実装でいいのかわからん + states_list = [hidden_states2[:, 0].unsqueeze(1)] # + for i in range(1, max_token_length, tokenizer2.model_max_length): + chunk = hidden_states2[:, i : i + tokenizer2.model_max_length - 2] # の後から 最後の前まで + # this causes an error: + # RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation + # if i > 1: + # for j in range(len(chunk)): # batch_size + # if input_ids2[n_index + j * n_size, 1] == tokenizer2.eos_token_id: # 空、つまり ...のパターン + # chunk[j, 0] = chunk[j, 1] # 次の の値をコピーする + states_list.append(chunk) # の後から の前まで + states_list.append(hidden_states2[:, -1].unsqueeze(1)) # のどちらか + hidden_states2 = torch.cat(states_list, dim=1) + + # pool はnの最初のものを使う + pool2 = pool2[::n_size] + + if weight_dtype is not None: + # this is required for additional network training + hidden_states1 = hidden_states1.to(weight_dtype) + hidden_states2 = hidden_states2.to(weight_dtype) + + return hidden_states1, hidden_states2, pool2 + + def default_if_none(value, default): return default if value is None else value @@ -3295,6 +4106,43 @@ def save_sd_model_on_epoch_end_or_stepwise( text_encoder, unet, vae, +): + def sd_saver(ckpt_file, epoch_no, global_step): + sai_metadata = get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True) + model_util.save_stable_diffusion_checkpoint( + args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, sai_metadata, save_dtype, vae + ) + + def diffusers_saver(out_dir): + model_util.save_diffusers_checkpoint( + args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors + ) + + save_sd_model_on_epoch_end_or_stepwise_common( + args, + on_epoch_end, + accelerator, + save_stable_diffusion_format, + use_safetensors, + epoch, + num_train_epochs, + global_step, + sd_saver, + diffusers_saver, + ) + + +def save_sd_model_on_epoch_end_or_stepwise_common( + args: argparse.Namespace, + on_epoch_end: bool, + accelerator, + save_stable_diffusion_format: bool, + use_safetensors: bool, + epoch: int, + num_train_epochs: int, + global_step: int, + sd_saver, + diffusers_saver, ): if on_epoch_end: epoch_no = epoch + 1 @@ -3322,9 +4170,7 @@ def save_sd_model_on_epoch_end_or_stepwise( ckpt_file = os.path.join(args.output_dir, ckpt_name) print(f"\nsaving checkpoint: {ckpt_file}") - model_util.save_stable_diffusion_checkpoint( - args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae - ) + sd_saver(ckpt_file, epoch_no, global_step) if args.huggingface_repo_id is not None: huggingface_util.upload(args, ckpt_file, "/" + ckpt_name) @@ -3348,9 +4194,8 @@ def save_sd_model_on_epoch_end_or_stepwise( out_dir = os.path.join(args.output_dir, STEP_DIFFUSERS_DIR_NAME.format(model_name, global_step)) print(f"\nsaving model: {out_dir}") - model_util.save_diffusers_checkpoint( - args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors - ) + diffusers_saver(out_dir) + if args.huggingface_repo_id is not None: huggingface_util.upload(args, out_dir, "/" + model_name) @@ -3443,6 +4288,31 @@ def save_sd_model_on_train_end( text_encoder, unet, vae, +): + def sd_saver(ckpt_file, epoch_no, global_step): + sai_metadata = get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True) + model_util.save_stable_diffusion_checkpoint( + args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, sai_metadata, save_dtype, vae + ) + + def diffusers_saver(out_dir): + 
model_util.save_diffusers_checkpoint( + args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors + ) + + save_sd_model_on_train_end_common( + args, save_stable_diffusion_format, use_safetensors, epoch, global_step, sd_saver, diffusers_saver + ) + + +def save_sd_model_on_train_end_common( + args: argparse.Namespace, + save_stable_diffusion_format: bool, + use_safetensors: bool, + epoch: int, + global_step: int, + sd_saver, + diffusers_saver, ): model_name = default_if_none(args.output_name, DEFAULT_LAST_OUTPUT_NAME) @@ -3453,9 +4323,8 @@ def save_sd_model_on_train_end( ckpt_file = os.path.join(args.output_dir, ckpt_name) print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}") - model_util.save_stable_diffusion_checkpoint( - args.v2, ckpt_file, text_encoder, unet, src_path, epoch, global_step, save_dtype, vae - ) + sd_saver(ckpt_file, epoch, global_step) + if args.huggingface_repo_id is not None: huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True) else: @@ -3463,13 +4332,40 @@ def save_sd_model_on_train_end( os.makedirs(out_dir, exist_ok=True) print(f"save trained model as Diffusers to {out_dir}") - model_util.save_diffusers_checkpoint( - args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors - ) + diffusers_saver(out_dir) + if args.huggingface_repo_id is not None: huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True) +def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents, device=latents.device) + if args.noise_offset: + noise = custom_train_functions.apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale) + if args.multires_noise_iterations: + noise = custom_train_functions.pyramid_noise_like( + noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount + ) + + # Sample a random timestep for each image + b_size = latents.shape[0] + min_timestep = 0 if args.min_timestep is None else args.min_timestep + max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep + + timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + if args.ip_noise_gamma: + noisy_latents = noise_scheduler.add_noise(latents, noise + args.ip_noise_gamma * torch.randn_like(latents), timesteps) + else: + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + return noise, noisy_latents, timesteps + + # scheduler: SCHEDULER_LINEAR_START = 0.00085 SCHEDULER_LINEAR_END = 0.0120 @@ -3477,8 +4373,23 @@ def save_sd_model_on_train_end( SCHEDLER_SCHEDULE = "scaled_linear" -def sample_images( - accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None +def sample_images(*args, **kwargs): + return sample_images_common(StableDiffusionLongPromptWeightingPipeline, *args, **kwargs) + + +def sample_images_common( + pipe_class, + accelerator, + args: argparse.Namespace, + epoch, + steps, + device, + vae, + tokenizer, + text_encoder, + unet, + prompt_replacement=None, + controlnet=None, ): """ StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した @@ -3562,16 +4473,16 @@ def sample_images( # print("set 
clip_sample to True") scheduler.config.clip_sample = True - pipeline = StableDiffusionLongPromptWeightingPipeline( + pipeline = pipe_class( text_encoder=text_encoder, vae=vae, unet=unet, tokenizer=tokenizer, scheduler=scheduler, - clip_skip=args.clip_skip, safety_checker=None, feature_extractor=None, requires_safety_checker=False, + clip_skip=args.clip_skip, ) pipeline.to(device) @@ -3582,114 +4493,130 @@ def sample_images( cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None with torch.no_grad(): - with accelerator.autocast(): - for i, prompt in enumerate(prompts): - if not accelerator.is_main_process: - continue + # with accelerator.autocast(): + for i, prompt in enumerate(prompts): + if not accelerator.is_main_process: + continue - if isinstance(prompt, dict): - negative_prompt = prompt.get("negative_prompt") - sample_steps = prompt.get("sample_steps", 30) - width = prompt.get("width", 512) - height = prompt.get("height", 512) - scale = prompt.get("scale", 7.5) - seed = prompt.get("seed") - prompt = prompt.get("prompt") - else: - # prompt = prompt.strip() - # if len(prompt) == 0 or prompt[0] == "#": - # continue - - # subset of gen_img_diffusers - prompt_args = prompt.split(" --") - prompt = prompt_args[0] - negative_prompt = None - sample_steps = 30 - width = height = 512 - scale = 7.5 - seed = None - for parg in prompt_args: - try: - m = re.match(r"w (\d+)", parg, re.IGNORECASE) - if m: - width = int(m.group(1)) - continue - - m = re.match(r"h (\d+)", parg, re.IGNORECASE) - if m: - height = int(m.group(1)) - continue - - m = re.match(r"d (\d+)", parg, re.IGNORECASE) - if m: - seed = int(m.group(1)) - continue - - m = re.match(r"s (\d+)", parg, re.IGNORECASE) - if m: # steps - sample_steps = max(1, min(1000, int(m.group(1)))) - continue - - m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) - if m: # scale - scale = float(m.group(1)) - continue - - m = re.match(r"n (.+)", parg, re.IGNORECASE) - if m: # negative prompt - negative_prompt = m.group(1) - continue - - except ValueError as ex: - print(f"Exception in parsing / 解析エラー: {parg}") - print(ex) - - if seed is not None: - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - - if prompt_replacement is not None: - prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) - if negative_prompt is not None: - negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) - - height = max(64, height - height % 8) # round to divisible by 8 - width = max(64, width - width % 8) # round to divisible by 8 - print(f"prompt: {prompt}") - print(f"negative_prompt: {negative_prompt}") - print(f"height: {height}") - print(f"width: {width}") - print(f"sample_steps: {sample_steps}") - print(f"scale: {scale}") - image = pipeline( + if isinstance(prompt, dict): + negative_prompt = prompt.get("negative_prompt") + sample_steps = prompt.get("sample_steps", 30) + width = prompt.get("width", 512) + height = prompt.get("height", 512) + scale = prompt.get("scale", 7.5) + seed = prompt.get("seed") + controlnet_image = prompt.get("controlnet_image") + prompt = prompt.get("prompt") + else: + # prompt = prompt.strip() + # if len(prompt) == 0 or prompt[0] == "#": + # continue + + # subset of gen_img_diffusers + prompt_args = prompt.split(" --") + prompt = prompt_args[0] + negative_prompt = None + sample_steps = 30 + width = height = 512 + scale = 7.5 + seed = None + controlnet_image = None + for parg in prompt_args: + try: + m = re.match(r"w (\d+)", parg, re.IGNORECASE) + if m: + width = 
int(m.group(1)) + continue + + m = re.match(r"h (\d+)", parg, re.IGNORECASE) + if m: + height = int(m.group(1)) + continue + + m = re.match(r"d (\d+)", parg, re.IGNORECASE) + if m: + seed = int(m.group(1)) + continue + + m = re.match(r"s (\d+)", parg, re.IGNORECASE) + if m: # steps + sample_steps = max(1, min(1000, int(m.group(1)))) + continue + + m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) + if m: # scale + scale = float(m.group(1)) + continue + + m = re.match(r"n (.+)", parg, re.IGNORECASE) + if m: # negative prompt + negative_prompt = m.group(1) + continue + + m = re.match(r"cn (.+)", parg, re.IGNORECASE) + if m: # negative prompt + controlnet_image = m.group(1) + continue + + except ValueError as ex: + print(f"Exception in parsing / 解析エラー: {parg}") + print(ex) + + if seed is not None: + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + if prompt_replacement is not None: + prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) + if negative_prompt is not None: + negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + + if controlnet_image is not None: + controlnet_image = Image.open(controlnet_image).convert("RGB") + controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS) + + height = max(64, height - height % 8) # round to divisible by 8 + width = max(64, width - width % 8) # round to divisible by 8 + print(f"prompt: {prompt}") + print(f"negative_prompt: {negative_prompt}") + print(f"height: {height}") + print(f"width: {width}") + print(f"sample_steps: {sample_steps}") + print(f"scale: {scale}") + with accelerator.autocast(): + latents = pipeline( prompt=prompt, height=height, width=width, num_inference_steps=sample_steps, guidance_scale=scale, negative_prompt=negative_prompt, - ).images[0] - - ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) - num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" - seed_suffix = "" if seed is None else f"_{seed}" - img_filename = ( - f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png" + controlnet=controlnet, + controlnet_image=controlnet_image, ) - image.save(os.path.join(save_dir, img_filename)) + image = pipeline.latents_to_image(latents)[0] + + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" + seed_suffix = "" if seed is None else f"_{seed}" + img_filename = ( + f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png" + ) + + image.save(os.path.join(save_dir, img_filename)) - # wandb有効時のみログを送信 + # wandb有効時のみログを送信 + try: + wandb_tracker = accelerator.get_tracker("wandb") try: - wandb_tracker = accelerator.get_tracker("wandb") - try: - import wandb - except ImportError: # 事前に一度確認するのでここはエラー出ないはず - raise ImportError("No wandb / wandb がインストールされていないようです") + import wandb + except ImportError: # 事前に一度確認するのでここはエラー出ないはず + raise ImportError("No wandb / wandb がインストールされていないようです") - wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) - except: # wandb 無効時 - pass + wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) + except: # wandb 無効時 + pass # clear pipeline and cache to reduce vram usage del pipeline diff --git a/networks/check_lora_weights.py b/networks/check_lora_weights.py index bb8dcd6ba..51f581b29 100644 --- a/networks/check_lora_weights.py +++ b/networks/check_lora_weights.py @@ -5,35 +5,41 @@ def main(file): - print(f"loading: {file}") - if 
os.path.splitext(file)[1] == '.safetensors': - sd = load_file(file) - else: - sd = torch.load(file, map_location='cpu') + print(f"loading: {file}") + if os.path.splitext(file)[1] == ".safetensors": + sd = load_file(file) + else: + sd = torch.load(file, map_location="cpu") - values = [] + values = [] - keys = list(sd.keys()) - for key in keys: - if 'lora_up' in key or 'lora_down' in key: - values.append((key, sd[key])) - print(f"number of LoRA modules: {len(values)}") + keys = list(sd.keys()) + for key in keys: + if "lora_up" in key or "lora_down" in key: + values.append((key, sd[key])) + print(f"number of LoRA modules: {len(values)}") - for key, value in values: - value = value.to(torch.float32) - print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}") + if args.show_all_keys: + for key in [k for k in keys if k not in values]: + values.append((key, sd[key])) + print(f"number of all modules: {len(values)}") + + for key, value in values: + value = value.to(torch.float32) + print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}") def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument("file", type=str, help="model file to check / 重みを確認するモデルファイル") + parser = argparse.ArgumentParser() + parser.add_argument("file", type=str, help="model file to check / 重みを確認するモデルファイル") + parser.add_argument("-s", "--show_all_keys", action="store_true", help="show all keys / 全てのキーを表示する") - return parser + return parser -if __name__ == '__main__': - parser = setup_parser() +if __name__ == "__main__": + parser = setup_parser() - args = parser.parse_args() + args = parser.parse_args() - main(args.file) + main(args.file) diff --git a/networks/control_net_lllite.py b/networks/control_net_lllite.py new file mode 100644 index 000000000..4ebfef7a4 --- /dev/null +++ b/networks/control_net_lllite.py @@ -0,0 +1,446 @@ +import os +from typing import Optional, List, Type +import torch +from library import sdxl_original_unet + + +# input_blocksに適用するかどうか / if True, input_blocks are not applied +SKIP_INPUT_BLOCKS = False + +# output_blocksに適用するかどうか / if True, output_blocks are not applied +SKIP_OUTPUT_BLOCKS = True + +# conv2dに適用するかどうか / if True, conv2d are not applied +SKIP_CONV2D = False + +# transformer_blocksのみに適用するかどうか。Trueの場合、ResBlockには適用されない +# if True, only transformer_blocks are applied, and ResBlocks are not applied +TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored because conv2d is not used in transformer_blocks + +# Trueならattn1とattn2にのみ適用し、ffなどには適用しない / if True, apply only to attn1 and attn2, not to ff etc. +ATTN1_2_ONLY = True + +# Trueならattn1のQKV、attn2のQにのみ適用する、ATTN1_2_ONLY指定時のみ有効 / if True, apply only to attn1 QKV and attn2 Q, only valid when ATTN1_2_ONLY is specified +ATTN_QKV_ONLY = True + +# Trueならattn1やffなどにのみ適用し、attn2などには適用しない / if True, apply only to attn1 and ff, not to attn2 +# ATTN1_2_ONLYと同時にTrueにできない / cannot be True at the same time as ATTN1_2_ONLY +ATTN1_ETC_ONLY = False # True + +# transformer_blocksの最大インデックス。Noneなら全てのtransformer_blocksに適用 +# max index of transformer_blocks. 
if None, apply to all transformer_blocks +TRANSFORMER_MAX_BLOCK_INDEX = None + + +class LLLiteModule(torch.nn.Module): + def __init__(self, depth, cond_emb_dim, name, org_module, mlp_dim, dropout=None, multiplier=1.0): + super().__init__() + + self.is_conv2d = org_module.__class__.__name__ == "Conv2d" + self.lllite_name = name + self.cond_emb_dim = cond_emb_dim + self.org_module = [org_module] + self.dropout = dropout + self.multiplier = multiplier + + if self.is_conv2d: + in_dim = org_module.in_channels + else: + in_dim = org_module.in_features + + # conditioning1はconditioning imageを embedding する。timestepごとに呼ばれない + # conditioning1 embeds conditioning image. it is not called for each timestep + modules = [] + modules.append(torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) # to latent (from VAE) size + if depth == 1: + modules.append(torch.nn.ReLU(inplace=True)) + modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0)) + elif depth == 2: + modules.append(torch.nn.ReLU(inplace=True)) + modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=4, stride=4, padding=0)) + elif depth == 3: + # kernel size 8は大きすぎるので、4にする / kernel size 8 is too large, so set it to 4 + modules.append(torch.nn.ReLU(inplace=True)) + modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) + modules.append(torch.nn.ReLU(inplace=True)) + modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0)) + + self.conditioning1 = torch.nn.Sequential(*modules) + + # downで入力の次元数を削減する。LoRAにヒントを得ていることにする + # midでconditioning image embeddingと入力を結合する + # upで元の次元数に戻す + # これらはtimestepごとに呼ばれる + # reduce the number of input dimensions with down. 
inspired by LoRA + # combine conditioning image embedding and input with mid + # restore to the original dimension with up + # these are called for each timestep + + if self.is_conv2d: + self.down = torch.nn.Sequential( + torch.nn.Conv2d(in_dim, mlp_dim, kernel_size=1, stride=1, padding=0), + torch.nn.ReLU(inplace=True), + ) + self.mid = torch.nn.Sequential( + torch.nn.Conv2d(mlp_dim + cond_emb_dim, mlp_dim, kernel_size=1, stride=1, padding=0), + torch.nn.ReLU(inplace=True), + ) + self.up = torch.nn.Sequential( + torch.nn.Conv2d(mlp_dim, in_dim, kernel_size=1, stride=1, padding=0), + ) + else: + # midの前にconditioningをreshapeすること / reshape conditioning before mid + self.down = torch.nn.Sequential( + torch.nn.Linear(in_dim, mlp_dim), + torch.nn.ReLU(inplace=True), + ) + self.mid = torch.nn.Sequential( + torch.nn.Linear(mlp_dim + cond_emb_dim, mlp_dim), + torch.nn.ReLU(inplace=True), + ) + self.up = torch.nn.Sequential( + torch.nn.Linear(mlp_dim, in_dim), + ) + + # Zero-Convにする / set to Zero-Conv + torch.nn.init.zeros_(self.up[0].weight) # zero conv + + self.depth = depth # 1~3 + self.cond_emb = None + self.batch_cond_only = False # Trueなら推論時のcondにのみ適用する / if True, apply only to cond at inference + self.use_zeros_for_batch_uncond = False # Trueならuncondのconditioningを0にする / if True, set uncond conditioning to 0 + + # batch_cond_onlyとuse_zeros_for_batch_uncondはどちらも適用すると生成画像の色味がおかしくなるので実際には使えそうにない + # Controlの種類によっては使えるかも + # both batch_cond_only and use_zeros_for_batch_uncond make the color of the generated image strange, so it doesn't seem to be usable in practice + # it may be available depending on the type of Control + + def set_cond_image(self, cond_image): + r""" + 中でモデルを呼び出すので必要ならwith torch.no_grad()で囲む + / call the model inside, so if necessary, surround it with torch.no_grad() + """ + if cond_image is None: + self.cond_emb = None + return + + # timestepごとに呼ばれないので、あらかじめ計算しておく / it is not called for each timestep, so calculate it in advance + # print(f"C {self.lllite_name}, cond_image.shape={cond_image.shape}") + cx = self.conditioning1(cond_image) + if not self.is_conv2d: + # reshape / b,c,h,w -> b,h*w,c + n, c, h, w = cx.shape + cx = cx.view(n, c, h * w).permute(0, 2, 1) + self.cond_emb = cx + + def set_batch_cond_only(self, cond_only, zeros): + self.batch_cond_only = cond_only + self.use_zeros_for_batch_uncond = zeros + + def apply_to(self): + self.org_forward = self.org_module[0].forward + self.org_module[0].forward = self.forward + + def forward(self, x): + r""" + 学習用の便利forward。元のモジュールのforwardを呼び出す + / convenient forward for training. 
call the forward of the original module + """ + if self.multiplier == 0.0 or self.cond_emb is None: + return self.org_forward(x) + + cx = self.cond_emb + + if not self.batch_cond_only and x.shape[0] // 2 == cx.shape[0]: # inference only + cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1) + if self.use_zeros_for_batch_uncond: + cx[0::2] = 0.0 # uncond is zero + # print(f"C {self.lllite_name}, x.shape={x.shape}, cx.shape={cx.shape}") + + # downで入力の次元数を削減し、conditioning image embeddingと結合する + # 加算ではなくchannel方向に結合することで、うまいこと混ぜてくれることを期待している + # down reduces the number of input dimensions and combines it with conditioning image embedding + # we expect that it will mix well by combining in the channel direction instead of adding + + cx = torch.cat([cx, self.down(x if not self.batch_cond_only else x[1::2])], dim=1 if self.is_conv2d else 2) + cx = self.mid(cx) + + if self.dropout is not None and self.training: + cx = torch.nn.functional.dropout(cx, p=self.dropout) + + cx = self.up(cx) * self.multiplier + + # residual (x) を加算して元のforwardを呼び出す / add residual (x) and call the original forward + if self.batch_cond_only: + zx = torch.zeros_like(x) + zx[1::2] += cx + cx = zx + + x = self.org_forward(x + cx) # ここで元のモジュールを呼び出す / call the original module here + return x + + +class ControlNetLLLite(torch.nn.Module): + UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"] + UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] + + def __init__( + self, + unet: sdxl_original_unet.SdxlUNet2DConditionModel, + cond_emb_dim: int = 16, + mlp_dim: int = 16, + dropout: Optional[float] = None, + varbose: Optional[bool] = False, + multiplier: Optional[float] = 1.0, + ) -> None: + super().__init__() + # self.unets = [unet] + + def create_modules( + root_module: torch.nn.Module, + target_replace_modules: List[torch.nn.Module], + module_class: Type[object], + ) -> List[torch.nn.Module]: + prefix = "lllite_unet" + + modules = [] + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + is_linear = child_module.__class__.__name__ == "Linear" + is_conv2d = child_module.__class__.__name__ == "Conv2d" + + if is_linear or (is_conv2d and not SKIP_CONV2D): + # block indexからdepthを計算: depthはconditioningのサイズやチャネルを計算するのに使う + # block index to depth: depth is using to calculate conditioning size and channels + block_name, index1, index2 = (name + "." + child_name).split(".")[:3] + index1 = int(index1) + if block_name == "input_blocks": + if SKIP_INPUT_BLOCKS: + continue + depth = 1 if index1 <= 2 else (2 if index1 <= 5 else 3) + elif block_name == "middle_block": + depth = 3 + elif block_name == "output_blocks": + if SKIP_OUTPUT_BLOCKS: + continue + depth = 3 if index1 <= 2 else (2 if index1 <= 5 else 1) + if int(index2) >= 2: + depth -= 1 + else: + raise NotImplementedError() + + lllite_name = prefix + "." + name + "." 
+ child_name + lllite_name = lllite_name.replace(".", "_") + + if TRANSFORMER_MAX_BLOCK_INDEX is not None: + p = lllite_name.find("transformer_blocks") + if p >= 0: + tf_index = int(lllite_name[p:].split("_")[2]) + if tf_index > TRANSFORMER_MAX_BLOCK_INDEX: + continue + + # time embは適用外とする + # attn2のconditioning (CLIPからの入力) はshapeが違うので適用できない + # time emb is not applied + # attn2 conditioning (input from CLIP) cannot be applied because the shape is different + if "emb_layers" in lllite_name or ( + "attn2" in lllite_name and ("to_k" in lllite_name or "to_v" in lllite_name) + ): + continue + + if ATTN1_2_ONLY: + if not ("attn1" in lllite_name or "attn2" in lllite_name): + continue + if ATTN_QKV_ONLY: + if "to_out" in lllite_name: + continue + + if ATTN1_ETC_ONLY: + if "proj_out" in lllite_name: + pass + elif "attn1" in lllite_name and ( + "to_k" in lllite_name or "to_v" in lllite_name or "to_out" in lllite_name + ): + pass + elif "ff_net_2" in lllite_name: + pass + else: + continue + + module = module_class( + depth, + cond_emb_dim, + lllite_name, + child_module, + mlp_dim, + dropout=dropout, + multiplier=multiplier, + ) + modules.append(module) + return modules + + target_modules = ControlNetLLLite.UNET_TARGET_REPLACE_MODULE + if not TRANSFORMER_ONLY: + target_modules = target_modules + ControlNetLLLite.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 + + # create module instances + self.unet_modules: List[LLLiteModule] = create_modules(unet, target_modules, LLLiteModule) + print(f"create ControlNet LLLite for U-Net: {len(self.unet_modules)} modules.") + + def forward(self, x): + return x # dummy + + def set_cond_image(self, cond_image): + r""" + 中でモデルを呼び出すので必要ならwith torch.no_grad()で囲む + / call the model inside, so if necessary, surround it with torch.no_grad() + """ + for module in self.unet_modules: + module.set_cond_image(cond_image) + + def set_batch_cond_only(self, cond_only, zeros): + for module in self.unet_modules: + module.set_batch_cond_only(cond_only, zeros) + + def set_multiplier(self, multiplier): + for module in self.unet_modules: + module.multiplier = multiplier + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + info = self.load_state_dict(weights_sd, False) + return info + + def apply_to(self): + print("applying LLLite for U-Net...") + for module in self.unet_modules: + module.apply_to() + self.add_module(module.lllite_name, module) + + # マージできるかどうかを返す + def is_mergeable(self): + return False + + def merge_to(self, text_encoder, unet, weights_sd, dtype, device): + raise NotImplementedError() + + def enable_gradient_checkpointing(self): + # not supported + pass + + def prepare_optimizer_params(self): + self.requires_grad_(True) + return self.parameters() + + def prepare_grad_etc(self): + self.requires_grad_(True) + + def on_epoch_start(self): + self.train() + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + + +if __name__ == 
"__main__": + # デバッグ用 / for debug + + # sdxl_original_unet.USE_REENTRANT = False + + # test shape etc + print("create unet") + unet = sdxl_original_unet.SdxlUNet2DConditionModel() + unet.to("cuda").to(torch.float16) + + print("create ControlNet-LLLite") + control_net = ControlNetLLLite(unet, 32, 64) + control_net.apply_to() + control_net.to("cuda") + + print(control_net) + + # print number of parameters + print("number of parameters", sum(p.numel() for p in control_net.parameters() if p.requires_grad)) + + input() + + unet.set_use_memory_efficient_attention(True, False) + unet.set_gradient_checkpointing(True) + unet.train() # for gradient checkpointing + + control_net.train() + + # # visualize + # import torchviz + # print("run visualize") + # controlnet.set_control(conditioning_image) + # output = unet(x, t, ctx, y) + # print("make_dot") + # image = torchviz.make_dot(output, params=dict(controlnet.named_parameters())) + # print("render") + # image.format = "svg" # "png" + # image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time + # input() + + import bitsandbytes + + optimizer = bitsandbytes.adam.Adam8bit(control_net.prepare_optimizer_params(), 1e-3) + + scaler = torch.cuda.amp.GradScaler(enabled=True) + + print("start training") + steps = 10 + + sample_param = [p for p in control_net.named_parameters() if "up" in p[0]][0] + for step in range(steps): + print(f"step {step}") + + batch_size = 1 + conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0 + x = torch.randn(batch_size, 4, 128, 128).cuda() + t = torch.randint(low=0, high=10, size=(batch_size,)).cuda() + ctx = torch.randn(batch_size, 77, 2048).cuda() + y = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda() + + with torch.cuda.amp.autocast(enabled=True): + control_net.set_cond_image(conditioning_image) + + output = unet(x, t, ctx, y) + target = torch.randn_like(output) + loss = torch.nn.functional.mse_loss(output, target) + + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad(set_to_none=True) + print(sample_param) + + # from safetensors.torch import save_file + + # save_file(control_net.state_dict(), "logs/control_net.safetensors") diff --git a/networks/control_net_lllite_for_train.py b/networks/control_net_lllite_for_train.py new file mode 100644 index 000000000..026880015 --- /dev/null +++ b/networks/control_net_lllite_for_train.py @@ -0,0 +1,502 @@ +# cond_imageをU-Netのforwardで渡すバージョンのControlNet-LLLite検証用実装 +# ControlNet-LLLite implementation for verification with cond_image passed in U-Net's forward + +import os +import re +from typing import Optional, List, Type +import torch +from library import sdxl_original_unet + + +# input_blocksに適用するかどうか / if True, input_blocks are not applied +SKIP_INPUT_BLOCKS = False + +# output_blocksに適用するかどうか / if True, output_blocks are not applied +SKIP_OUTPUT_BLOCKS = True + +# conv2dに適用するかどうか / if True, conv2d are not applied +SKIP_CONV2D = False + +# transformer_blocksのみに適用するかどうか。Trueの場合、ResBlockには適用されない +# if True, only transformer_blocks are applied, and ResBlocks are not applied +TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored because conv2d is not used in transformer_blocks + +# Trueならattn1とattn2にのみ適用し、ffなどには適用しない / if True, apply only to attn1 and attn2, not to ff etc. 
+ATTN1_2_ONLY = True + +# Trueならattn1のQKV、attn2のQにのみ適用する、ATTN1_2_ONLY指定時のみ有効 / if True, apply only to attn1 QKV and attn2 Q, only valid when ATTN1_2_ONLY is specified +ATTN_QKV_ONLY = True + +# Trueならattn1やffなどにのみ適用し、attn2などには適用しない / if True, apply only to attn1 and ff, not to attn2 +# ATTN1_2_ONLYと同時にTrueにできない / cannot be True at the same time as ATTN1_2_ONLY +ATTN1_ETC_ONLY = False # True + +# transformer_blocksの最大インデックス。Noneなら全てのtransformer_blocksに適用 +# max index of transformer_blocks. if None, apply to all transformer_blocks +TRANSFORMER_MAX_BLOCK_INDEX = None + +ORIGINAL_LINEAR = torch.nn.Linear +ORIGINAL_CONV2D = torch.nn.Conv2d + + +def add_lllite_modules(module: torch.nn.Module, in_dim: int, depth, cond_emb_dim, mlp_dim) -> None: + # conditioning1はconditioning imageを embedding する。timestepごとに呼ばれない + # conditioning1 embeds conditioning image. it is not called for each timestep + modules = [] + modules.append(ORIGINAL_CONV2D(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) # to latent (from VAE) size + if depth == 1: + modules.append(torch.nn.ReLU(inplace=True)) + modules.append(ORIGINAL_CONV2D(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0)) + elif depth == 2: + modules.append(torch.nn.ReLU(inplace=True)) + modules.append(ORIGINAL_CONV2D(cond_emb_dim // 2, cond_emb_dim, kernel_size=4, stride=4, padding=0)) + elif depth == 3: + # kernel size 8は大きすぎるので、4にする / kernel size 8 is too large, so set it to 4 + modules.append(torch.nn.ReLU(inplace=True)) + modules.append(ORIGINAL_CONV2D(cond_emb_dim // 2, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) + modules.append(torch.nn.ReLU(inplace=True)) + modules.append(ORIGINAL_CONV2D(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0)) + + module.lllite_conditioning1 = torch.nn.Sequential(*modules) + + # downで入力の次元数を削減する。LoRAにヒントを得ていることにする + # midでconditioning image embeddingと入力を結合する + # upで元の次元数に戻す + # これらはtimestepごとに呼ばれる + # reduce the number of input dimensions with down. 
inspired by LoRA + # combine conditioning image embedding and input with mid + # restore to the original dimension with up + # these are called for each timestep + + module.lllite_down = torch.nn.Sequential( + ORIGINAL_LINEAR(in_dim, mlp_dim), + torch.nn.ReLU(inplace=True), + ) + module.lllite_mid = torch.nn.Sequential( + ORIGINAL_LINEAR(mlp_dim + cond_emb_dim, mlp_dim), + torch.nn.ReLU(inplace=True), + ) + module.lllite_up = torch.nn.Sequential( + ORIGINAL_LINEAR(mlp_dim, in_dim), + ) + + # Zero-Convにする / set to Zero-Conv + torch.nn.init.zeros_(module.lllite_up[0].weight) # zero conv + + +class LLLiteLinear(ORIGINAL_LINEAR): + def __init__(self, in_features: int, out_features: int, **kwargs): + super().__init__(in_features, out_features, **kwargs) + self.enabled = False + + def set_lllite(self, depth, cond_emb_dim, name, mlp_dim, dropout=None, multiplier=1.0): + self.enabled = True + self.lllite_name = name + self.cond_emb_dim = cond_emb_dim + self.dropout = dropout + self.multiplier = multiplier # ignored + + in_dim = self.in_features + add_lllite_modules(self, in_dim, depth, cond_emb_dim, mlp_dim) + + self.cond_image = None + self.cond_emb = None + + def set_cond_image(self, cond_image): + self.cond_image = cond_image + self.cond_emb = None + + def forward(self, x): + if not self.enabled: + return super().forward(x) + + if self.cond_emb is None: + self.cond_emb = self.lllite_conditioning1(self.cond_image) + cx = self.cond_emb + + # reshape / b,c,h,w -> b,h*w,c + n, c, h, w = cx.shape + cx = cx.view(n, c, h * w).permute(0, 2, 1) + + cx = torch.cat([cx, self.lllite_down(x)], dim=2) + cx = self.lllite_mid(cx) + + if self.dropout is not None and self.training: + cx = torch.nn.functional.dropout(cx, p=self.dropout) + + cx = self.lllite_up(cx) * self.multiplier + + x = super().forward(x + cx) # ここで元のモジュールを呼び出す / call the original module here + return x + + +class LLLiteConv2d(ORIGINAL_CONV2D): + def __init__(self, in_channels: int, out_channels: int, kernel_size, **kwargs): + super().__init__(in_channels, out_channels, kernel_size, **kwargs) + self.enabled = False + + def set_lllite(self, depth, cond_emb_dim, name, mlp_dim, dropout=None, multiplier=1.0): + self.enabled = True + self.lllite_name = name + self.cond_emb_dim = cond_emb_dim + self.dropout = dropout + self.multiplier = multiplier # ignored + + in_dim = self.in_channels + add_lllite_modules(self, in_dim, depth, cond_emb_dim, mlp_dim) + + self.cond_image = None + self.cond_emb = None + + def set_cond_image(self, cond_image): + self.cond_image = cond_image + self.cond_emb = None + + def forward(self, x): # , cond_image=None): + if not self.enabled: + return super().forward(x) + + if self.cond_emb is None: + self.cond_emb = self.lllite_conditioning1(self.cond_image) + cx = self.cond_emb + + cx = torch.cat([cx, self.down(x)], dim=1) + cx = self.mid(cx) + + if self.dropout is not None and self.training: + cx = torch.nn.functional.dropout(cx, p=self.dropout) + + cx = self.up(cx) * self.multiplier + + x = super().forward(x + cx) # ここで元のモジュールを呼び出す / call the original module here + return x + + +class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DConditionModel): + UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"] + UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] + LLLITE_PREFIX = "lllite_unet" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def apply_lllite( + self, + cond_emb_dim: int = 16, + mlp_dim: int = 16, + dropout: Optional[float] = None, + varbose: 
Optional[bool] = False, + multiplier: Optional[float] = 1.0, + ) -> None: + def apply_to_modules( + root_module: torch.nn.Module, + target_replace_modules: List[torch.nn.Module], + ) -> List[torch.nn.Module]: + prefix = "lllite_unet" + + modules = [] + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + is_linear = child_module.__class__.__name__ == "LLLiteLinear" + is_conv2d = child_module.__class__.__name__ == "LLLiteConv2d" + + if is_linear or (is_conv2d and not SKIP_CONV2D): + # block indexからdepthを計算: depthはconditioningのサイズやチャネルを計算するのに使う + # block index to depth: depth is using to calculate conditioning size and channels + block_name, index1, index2 = (name + "." + child_name).split(".")[:3] + index1 = int(index1) + if block_name == "input_blocks": + if SKIP_INPUT_BLOCKS: + continue + depth = 1 if index1 <= 2 else (2 if index1 <= 5 else 3) + elif block_name == "middle_block": + depth = 3 + elif block_name == "output_blocks": + if SKIP_OUTPUT_BLOCKS: + continue + depth = 3 if index1 <= 2 else (2 if index1 <= 5 else 1) + if int(index2) >= 2: + depth -= 1 + else: + raise NotImplementedError() + + lllite_name = prefix + "." + name + "." + child_name + lllite_name = lllite_name.replace(".", "_") + + if TRANSFORMER_MAX_BLOCK_INDEX is not None: + p = lllite_name.find("transformer_blocks") + if p >= 0: + tf_index = int(lllite_name[p:].split("_")[2]) + if tf_index > TRANSFORMER_MAX_BLOCK_INDEX: + continue + + # time embは適用外とする + # attn2のconditioning (CLIPからの入力) はshapeが違うので適用できない + # time emb is not applied + # attn2 conditioning (input from CLIP) cannot be applied because the shape is different + if "emb_layers" in lllite_name or ( + "attn2" in lllite_name and ("to_k" in lllite_name or "to_v" in lllite_name) + ): + continue + + if ATTN1_2_ONLY: + if not ("attn1" in lllite_name or "attn2" in lllite_name): + continue + if ATTN_QKV_ONLY: + if "to_out" in lllite_name: + continue + + if ATTN1_ETC_ONLY: + if "proj_out" in lllite_name: + pass + elif "attn1" in lllite_name and ( + "to_k" in lllite_name or "to_v" in lllite_name or "to_out" in lllite_name + ): + pass + elif "ff_net_2" in lllite_name: + pass + else: + continue + + child_module.set_lllite(depth, cond_emb_dim, lllite_name, mlp_dim, dropout, multiplier) + modules.append(child_module) + + return modules + + target_modules = SdxlUNet2DConditionModelControlNetLLLite.UNET_TARGET_REPLACE_MODULE + if not TRANSFORMER_ONLY: + target_modules = target_modules + SdxlUNet2DConditionModelControlNetLLLite.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 + + # create module instances + self.lllite_modules = apply_to_modules(self, target_modules) + print(f"enable ControlNet LLLite for U-Net: {len(self.lllite_modules)} modules.") + + # def prepare_optimizer_params(self): + def prepare_params(self): + train_params = [] + non_train_params = [] + for name, p in self.named_parameters(): + if "lllite" in name: + train_params.append(p) + else: + non_train_params.append(p) + print(f"count of trainable parameters: {len(train_params)}") + print(f"count of non-trainable parameters: {len(non_train_params)}") + + for p in non_train_params: + p.requires_grad_(False) + + # without this, an error occurs in the optimizer + # RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn + non_train_params[0].requires_grad_(True) + + for p in train_params: + p.requires_grad_(True) + + return train_params + + # def prepare_grad_etc(self): + # 
self.requires_grad_(True) + + # def on_epoch_start(self): + # self.train() + + def get_trainable_params(self): + return [p[1] for p in self.named_parameters() if "lllite" in p[0]] + + def save_lllite_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + org_state_dict = self.state_dict() + + # copy LLLite keys from org_state_dict to state_dict with key conversion + state_dict = {} + for key in org_state_dict.keys(): + # split with ".lllite" + pos = key.find(".lllite") + if pos < 0: + continue + lllite_key = SdxlUNet2DConditionModelControlNetLLLite.LLLITE_PREFIX + "." + key[:pos] + lllite_key = lllite_key.replace(".", "_") + key[pos:] + lllite_key = lllite_key.replace(".lllite_", ".") + state_dict[lllite_key] = org_state_dict[key] + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + + def load_lllite_weights(self, file, non_lllite_unet_sd=None): + r""" + LLLiteの重みを読み込まない(initされた値を使う)場合はfileにNoneを指定する。 + この場合、non_lllite_unet_sdにはU-Netのstate_dictを指定する。 + + If you do not want to load LLLite weights (use initialized values), specify None for file. + In this case, specify the state_dict of U-Net for non_lllite_unet_sd. + """ + if not file: + state_dict = self.state_dict() + for key in non_lllite_unet_sd: + if key in state_dict: + state_dict[key] = non_lllite_unet_sd[key] + info = self.load_state_dict(state_dict, False) + return info + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + # module_name = module_name.replace("_block", "@blocks") + # module_name = module_name.replace("_layer", "@layer") + # module_name = module_name.replace("to_", "to@") + # module_name = module_name.replace("time_embed", "time@embed") + # module_name = module_name.replace("label_emb", "label@emb") + # module_name = module_name.replace("skip_connection", "skip@connection") + # module_name = module_name.replace("proj_in", "proj@in") + # module_name = module_name.replace("proj_out", "proj@out") + pattern = re.compile(r"(_block|_layer|to_|time_embed|label_emb|skip_connection|proj_in|proj_out)") + + # convert to lllite with U-Net state dict + state_dict = non_lllite_unet_sd.copy() if non_lllite_unet_sd is not None else {} + for key in weights_sd.keys(): + # split with "." + pos = key.find(".") + if pos < 0: + continue + + module_name = key[:pos] + weight_name = key[pos + 1 :] # exclude "." + module_name = module_name.replace(SdxlUNet2DConditionModelControlNetLLLite.LLLITE_PREFIX + "_", "") + + # これはうまくいかない。逆変換を考えなかった設計が悪い / this does not work well. 
bad design because I didn't think about inverse conversion + # module_name = module_name.replace("_", ".") + + # ださいけどSDXLのU-Netの "_" を "@" に変換する / ugly but convert "_" of SDXL U-Net to "@" + matches = pattern.findall(module_name) + if matches is not None: + for m in matches: + print(module_name, m) + module_name = module_name.replace(m, m.replace("_", "@")) + module_name = module_name.replace("_", ".") + module_name = module_name.replace("@", "_") + + lllite_key = module_name + ".lllite_" + weight_name + + state_dict[lllite_key] = weights_sd[key] + + info = self.load_state_dict(state_dict, False) + return info + + def forward(self, x, timesteps=None, context=None, y=None, cond_image=None, **kwargs): + for m in self.lllite_modules: + m.set_cond_image(cond_image) + return super().forward(x, timesteps, context, y, **kwargs) + + +def replace_unet_linear_and_conv2d(): + print("replace torch.nn.Linear and torch.nn.Conv2d to LLLiteLinear and LLLiteConv2d in U-Net") + sdxl_original_unet.torch.nn.Linear = LLLiteLinear + sdxl_original_unet.torch.nn.Conv2d = LLLiteConv2d + + +if __name__ == "__main__": + # デバッグ用 / for debug + + # sdxl_original_unet.USE_REENTRANT = False + replace_unet_linear_and_conv2d() + + # test shape etc + print("create unet") + unet = SdxlUNet2DConditionModelControlNetLLLite() + + print("enable ControlNet-LLLite") + unet.apply_lllite(32, 64, None, False, 1.0) + unet.to("cuda") # .to(torch.float16) + + # from safetensors.torch import load_file + + # model_sd = load_file(r"E:\Work\SD\Models\sdxl\sd_xl_base_1.0_0.9vae.safetensors") + # unet_sd = {} + + # # copy U-Net keys from unet_state_dict to state_dict + # prefix = "model.diffusion_model." + # for key in model_sd.keys(): + # if key.startswith(prefix): + # converted_key = key[len(prefix) :] + # unet_sd[converted_key] = model_sd[key] + + # info = unet.load_lllite_weights("r:/lllite_from_unet.safetensors", unet_sd) + # print(info) + + # print(unet) + + # print number of parameters + params = unet.prepare_params() + print("number of parameters", sum(p.numel() for p in params)) + # print("type any key to continue") + # input() + + unet.set_use_memory_efficient_attention(True, False) + unet.set_gradient_checkpointing(True) + unet.train() # for gradient checkpointing + + # # visualize + # import torchviz + # print("run visualize") + # controlnet.set_control(conditioning_image) + # output = unet(x, t, ctx, y) + # print("make_dot") + # image = torchviz.make_dot(output, params=dict(controlnet.named_parameters())) + # print("render") + # image.format = "svg" # "png" + # image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time + # input() + + import bitsandbytes + + optimizer = bitsandbytes.adam.Adam8bit(params, 1e-3) + + scaler = torch.cuda.amp.GradScaler(enabled=True) + + print("start training") + steps = 10 + batch_size = 1 + + sample_param = [p for p in unet.named_parameters() if ".lllite_up." 
in p[0]][0] + for step in range(steps): + print(f"step {step}") + + conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0 + x = torch.randn(batch_size, 4, 128, 128).cuda() + t = torch.randint(low=0, high=10, size=(batch_size,)).cuda() + ctx = torch.randn(batch_size, 77, 2048).cuda() + y = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda() + + with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16): + output = unet(x, t, ctx, y, conditioning_image) + target = torch.randn_like(output) + loss = torch.nn.functional.mse_loss(output, target) + + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad(set_to_none=True) + print(sample_param) + + # from safetensors.torch import save_file + + # print("save weights") + # unet.save_lllite_weights("r:/lllite_from_unet.safetensors", torch.float16, None) diff --git a/networks/dylora.py b/networks/dylora.py index 90b509dfc..e5a55d198 100644 --- a/networks/dylora.py +++ b/networks/dylora.py @@ -239,7 +239,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh class DyLoRANetwork(torch.nn.Module): - UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"] + UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"] UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] LORA_PREFIX_UNET = "lora_unet" diff --git a/networks/extract_lora_from_models.py b/networks/extract_lora_from_models.py index f001e7eb2..dba7cd4e2 100644 --- a/networks/extract_lora_from_models.py +++ b/networks/extract_lora_from_models.py @@ -3,187 +3,265 @@ # Thanks to cloneofsimo! import argparse +import json import os +import time import torch from safetensors.torch import load_file, save_file from tqdm import tqdm -import library.model_util as model_util +from library import sai_model_spec, model_util, sdxl_model_util import lora CLAMP_QUANTILE = 0.99 -MIN_DIFF = 1e-6 +MIN_DIFF = 1e-1 def save_to_file(file_name, model, state_dict, dtype): - if dtype is not None: - for key in list(state_dict.keys()): - if type(state_dict[key]) == torch.Tensor: - state_dict[key] = state_dict[key].to(dtype) + if dtype is not None: + for key in list(state_dict.keys()): + if type(state_dict[key]) == torch.Tensor: + state_dict[key] = state_dict[key].to(dtype) - if os.path.splitext(file_name)[1] == '.safetensors': - save_file(model, file_name) - else: - torch.save(model, file_name) + if os.path.splitext(file_name)[1] == ".safetensors": + save_file(model, file_name) + else: + torch.save(model, file_name) def svd(args): - def str_to_dtype(p): - if p == 'float': - return torch.float - if p == 'fp16': - return torch.float16 - if p == 'bf16': - return torch.bfloat16 - return None - - save_dtype = str_to_dtype(args.save_precision) - - print(f"loading SD model : {args.model_org}") - text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org) - print(f"loading SD model : {args.model_tuned}") - text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned) - - # create LoRA network to extract weights: Use dim (rank) as alpha - if args.conv_dim is None: - kwargs = {} - else: - kwargs = {"conv_dim": args.conv_dim, "conv_alpha": args.conv_dim} - - lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_o, unet_o, **kwargs) - lora_network_t = lora.create_network(1.0, args.dim, args.dim, 
None, text_encoder_t, unet_t, **kwargs) - assert len(lora_network_o.text_encoder_loras) == len( - lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) " - - # get diffs - diffs = {} - text_encoder_different = False - for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras)): - lora_name = lora_o.lora_name - module_o = lora_o.org_module - module_t = lora_t.org_module - diff = module_t.weight - module_o.weight - - # Text Encoder might be same - if torch.max(torch.abs(diff)) > MIN_DIFF: - text_encoder_different = True - - diff = diff.float() - diffs[lora_name] = diff - - if not text_encoder_different: - print("Text encoder is same. Extract U-Net only.") - lora_network_o.text_encoder_loras = [] + def str_to_dtype(p): + if p == "float": + return torch.float + if p == "fp16": + return torch.float16 + if p == "bf16": + return torch.bfloat16 + return None + + assert args.v2 != args.sdxl or ( + not args.v2 and not args.sdxl + ), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません" + if args.v_parameterization is None: + args.v_parameterization = args.v2 + + save_dtype = str_to_dtype(args.save_precision) + + # load models + if not args.sdxl: + print(f"loading original SD model : {args.model_org}") + text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org) + text_encoders_o = [text_encoder_o] + print(f"loading tuned SD model : {args.model_tuned}") + text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned) + text_encoders_t = [text_encoder_t] + model_version = model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization) + else: + print(f"loading original SDXL model : {args.model_org}") + text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( + sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.model_org, "cpu" + ) + text_encoders_o = [text_encoder_o1, text_encoder_o2] + print(f"loading original SDXL model : {args.model_tuned}") + text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( + sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.model_tuned, "cpu" + ) + text_encoders_t = [text_encoder_t1, text_encoder_t2] + model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0 + + # create LoRA network to extract weights: Use dim (rank) as alpha + if args.conv_dim is None: + kwargs = {} + else: + kwargs = {"conv_dim": args.conv_dim, "conv_alpha": args.conv_dim} + + lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoders_o, unet_o, **kwargs) + lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoders_t, unet_t, **kwargs) + assert len(lora_network_o.text_encoder_loras) == len( + lora_network_t.text_encoder_loras + ), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) " + + # get diffs diffs = {} - - for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)): - lora_name = lora_o.lora_name - module_o = lora_o.org_module - module_t = lora_t.org_module - diff = module_t.weight - module_o.weight - diff = diff.float() - - if args.device: - diff = diff.to(args.device) - - diffs[lora_name] = diff - - # make LoRA with svd - print("calculating by svd") - lora_weights = {} - with torch.no_grad(): - for lora_name, mat in 
tqdm(list(diffs.items())): - # if args.conv_dim is None, diffs do not include LoRAs for conv2d-3x3 - conv2d = (len(mat.size()) == 4) - kernel_size = None if not conv2d else mat.size()[2:4] - conv2d_3x3 = conv2d and kernel_size != (1, 1) - - rank = args.dim if not conv2d_3x3 or args.conv_dim is None else args.conv_dim - out_dim, in_dim = mat.size()[0:2] - - if args.device: - mat = mat.to(args.device) - - # print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim) - rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim - - if conv2d: - if conv2d_3x3: - mat = mat.flatten(start_dim=1) - else: - mat = mat.squeeze() - - U, S, Vh = torch.linalg.svd(mat) - - U = U[:, :rank] - S = S[:rank] - U = U @ torch.diag(S) - - Vh = Vh[:rank, :] - - dist = torch.cat([U.flatten(), Vh.flatten()]) - hi_val = torch.quantile(dist, CLAMP_QUANTILE) - low_val = -hi_val - - U = U.clamp(low_val, hi_val) - Vh = Vh.clamp(low_val, hi_val) - - if conv2d: - U = U.reshape(out_dim, rank, 1, 1) - Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1]) - - U = U.to("cpu").contiguous() - Vh = Vh.to("cpu").contiguous() - - lora_weights[lora_name] = (U, Vh) - - # make state dict for LoRA - lora_sd = {} - for lora_name, (up_weight, down_weight) in lora_weights.items(): - lora_sd[lora_name + '.lora_up.weight'] = up_weight - lora_sd[lora_name + '.lora_down.weight'] = down_weight - lora_sd[lora_name + '.alpha'] = torch.tensor(down_weight.size()[0]) - - # load state dict to LoRA and save it - lora_network_save, lora_sd = lora.create_network_from_weights(1.0, None, None, text_encoder_o, unet_o, weights_sd=lora_sd) - lora_network_save.apply_to(text_encoder_o, unet_o) # create internal module references for state_dict - - info = lora_network_save.load_state_dict(lora_sd) - print(f"Loading extracted LoRA weights: {info}") - - dir_name = os.path.dirname(args.save_to) - if dir_name and not os.path.exists(dir_name): - os.makedirs(dir_name, exist_ok=True) - - # minimum metadata - metadata = {"ss_network_module": "networks.lora", "ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)} - - lora_network_save.save_weights(args.save_to, save_dtype, metadata) - print(f"LoRA weights are saved to: {args.save_to}") + text_encoder_different = False + for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras)): + lora_name = lora_o.lora_name + module_o = lora_o.org_module + module_t = lora_t.org_module + diff = module_t.weight - module_o.weight + + # Text Encoder might be same + if not text_encoder_different and torch.max(torch.abs(diff)) > MIN_DIFF: + text_encoder_different = True + print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {MIN_DIFF}") + + diff = diff.float() + diffs[lora_name] = diff + + if not text_encoder_different: + print("Text encoder is same. 
Extract U-Net only.") + lora_network_o.text_encoder_loras = [] + diffs = {} + + for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)): + lora_name = lora_o.lora_name + module_o = lora_o.org_module + module_t = lora_t.org_module + diff = module_t.weight - module_o.weight + diff = diff.float() + + if args.device: + diff = diff.to(args.device) + + diffs[lora_name] = diff + + # make LoRA with svd + print("calculating by svd") + lora_weights = {} + with torch.no_grad(): + for lora_name, mat in tqdm(list(diffs.items())): + # if args.conv_dim is None, diffs do not include LoRAs for conv2d-3x3 + conv2d = len(mat.size()) == 4 + kernel_size = None if not conv2d else mat.size()[2:4] + conv2d_3x3 = conv2d and kernel_size != (1, 1) + + rank = args.dim if not conv2d_3x3 or args.conv_dim is None else args.conv_dim + out_dim, in_dim = mat.size()[0:2] + + if args.device: + mat = mat.to(args.device) + + # print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim) + rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim + + if conv2d: + if conv2d_3x3: + mat = mat.flatten(start_dim=1) + else: + mat = mat.squeeze() + + U, S, Vh = torch.linalg.svd(mat) + + U = U[:, :rank] + S = S[:rank] + U = U @ torch.diag(S) + + Vh = Vh[:rank, :] + + dist = torch.cat([U.flatten(), Vh.flatten()]) + hi_val = torch.quantile(dist, CLAMP_QUANTILE) + low_val = -hi_val + + U = U.clamp(low_val, hi_val) + Vh = Vh.clamp(low_val, hi_val) + + if conv2d: + U = U.reshape(out_dim, rank, 1, 1) + Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1]) + + U = U.to("cpu").contiguous() + Vh = Vh.to("cpu").contiguous() + + lora_weights[lora_name] = (U, Vh) + + # make state dict for LoRA + lora_sd = {} + for lora_name, (up_weight, down_weight) in lora_weights.items(): + lora_sd[lora_name + ".lora_up.weight"] = up_weight + lora_sd[lora_name + ".lora_down.weight"] = down_weight + lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0]) + + # load state dict to LoRA and save it + lora_network_save, lora_sd = lora.create_network_from_weights(1.0, None, None, text_encoders_o, unet_o, weights_sd=lora_sd) + lora_network_save.apply_to(text_encoders_o, unet_o) # create internal module references for state_dict + + info = lora_network_save.load_state_dict(lora_sd) + print(f"Loading extracted LoRA weights: {info}") + + dir_name = os.path.dirname(args.save_to) + if dir_name and not os.path.exists(dir_name): + os.makedirs(dir_name, exist_ok=True) + + # minimum metadata + net_kwargs = {} + if args.conv_dim is not None: + net_kwargs["conv_dim"] = args.conv_dim + net_kwargs["conv_alpha"] = args.conv_dim + + metadata = { + "ss_v2": str(args.v2), + "ss_base_model_version": model_version, + "ss_network_module": "networks.lora", + "ss_network_dim": str(args.dim), + "ss_network_alpha": str(args.dim), + "ss_network_args": json.dumps(net_kwargs), + } + + if not args.no_metadata: + title = os.path.splitext(os.path.basename(args.save_to))[0] + sai_metadata = sai_model_spec.build_metadata( + None, args.v2, args.v_parameterization, args.sdxl, True, False, time.time(), title=title + ) + metadata.update(sai_metadata) + + lora_network_save.save_weights(args.save_to, save_dtype, metadata) + print(f"LoRA weights are saved to: {args.save_to}") def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument("--v2", action='store_true', - help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む') - parser.add_argument("--save_precision", 
type=str, default=None, - choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はfloat") - parser.add_argument("--model_org", type=str, default=None, - help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors") - parser.add_argument("--model_tuned", type=str, default=None, - help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors") - parser.add_argument("--save_to", type=str, default=None, - help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors") - parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)") - parser.add_argument("--conv_dim", type=int, default=None, - help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数(rank)(デフォルトNone、適用なし)") - parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") - - return parser - - -if __name__ == '__main__': - parser = setup_parser() - - args = parser.parse_args() - svd(args) + parser = argparse.ArgumentParser() + parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む") + parser.add_argument( + "--v_parameterization", + type=bool, + default=None, + help="make LoRA metadata for v-parameterization (default is same to v2) / 作成するLoRAのメタデータにv-parameterization用と設定する(省略時はv2と同じ)", + ) + parser.add_argument( + "--sdxl", action="store_true", help="load Stable Diffusion SDXL base model / Stable Diffusion SDXL baseのモデルを読み込む" + ) + parser.add_argument( + "--save_precision", + type=str, + default=None, + choices=[None, "float", "fp16", "bf16"], + help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はfloat", + ) + parser.add_argument( + "--model_org", + type=str, + default=None, + help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors", + ) + parser.add_argument( + "--model_tuned", + type=str, + default=None, + help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors", + ) + parser.add_argument( + "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors" + ) + parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)") + parser.add_argument( + "--conv_dim", + type=int, + default=None, + help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数(rank)(デフォルトNone、適用なし)", + ) + parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") + parser.add_argument( + "--no_metadata", + action="store_true", + help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / " + + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)", + ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + svd(args) diff --git a/networks/lora.py b/networks/lora.py index 27f59344c..0c75cd428 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -5,7 +5,9 @@ import math import os -from typing import List, Tuple, Union +from typing import Dict, List, Optional, Tuple, Type, Union +from diffusers import 
AutoencoderKL +from transformers import CLIPTextModel import numpy as np import torch import re @@ -239,9 +241,13 @@ def get_mask_for_x(self, x): else: area = x.size()[1] - mask = self.network.mask_dic[area] + mask = self.network.mask_dic.get(area, None) if mask is None: - raise ValueError(f"mask is None for resolution {area}") + # raise ValueError(f"mask is None for resolution {area}") + # emb_layers in SDXL doesn't have mask + # print(f"mask is None for resolution {area}, {x.size()}") + mask_size = (1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1) + return torch.ones(mask_size, dtype=x.dtype, device=x.device) / self.network.num_sub_prompts if len(x.size()) != 4: mask = torch.reshape(mask, (1, -1, 1)) return mask @@ -346,9 +352,10 @@ def to_out_forward(self, x): out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond # print("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts) - # for i in range(len(masks)): - # if masks[i] is None: - # masks[i] = torch.zeros_like(masks[-1]) + # if num_sub_prompts > num of LoRAs, fill with zero + for i in range(len(masks)): + if masks[i] is None: + masks[i] = torch.zeros_like(masks[0]) mask = torch.cat(masks) mask_sum = torch.sum(mask, dim=0) + 1e-4 @@ -400,7 +407,16 @@ def parse_block_lr_kwargs(nw_kwargs): return down_lr_weight, mid_lr_weight, up_lr_weight -def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, neuron_dropout=None, **kwargs): +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + vae: AutoencoderKL, + text_encoder: Union[CLIPTextModel, List[CLIPTextModel]], + unet, + neuron_dropout: Optional[float] = None, + **kwargs, +): if network_dim is None: network_dim = 4 # default if network_alpha is None: @@ -719,33 +735,36 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh class LoRANetwork(torch.nn.Module): NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数 - # is it possible to apply conv_in and conv_out? 
-> yes, newer LoCon supports it (^^;) - UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"] + UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"] UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] LORA_PREFIX_UNET = "lora_unet" LORA_PREFIX_TEXT_ENCODER = "lora_te" + # SDXL: must starts with LORA_PREFIX_TEXT_ENCODER + LORA_PREFIX_TEXT_ENCODER1 = "lora_te1" + LORA_PREFIX_TEXT_ENCODER2 = "lora_te2" + def __init__( self, - text_encoder, + text_encoder: Union[List[CLIPTextModel], CLIPTextModel], unet, - multiplier=1.0, - lora_dim=4, - alpha=1, - dropout=None, - rank_dropout=None, - module_dropout=None, - conv_lora_dim=None, - conv_alpha=None, - block_dims=None, - block_alphas=None, - conv_block_dims=None, - conv_block_alphas=None, - modules_dim=None, - modules_alpha=None, - module_class=LoRAModule, - varbose=False, + multiplier: float = 1.0, + lora_dim: int = 4, + alpha: float = 1, + dropout: Optional[float] = None, + rank_dropout: Optional[float] = None, + module_dropout: Optional[float] = None, + conv_lora_dim: Optional[int] = None, + conv_alpha: Optional[float] = None, + block_dims: Optional[List[int]] = None, + block_alphas: Optional[List[float]] = None, + conv_block_dims: Optional[List[int]] = None, + conv_block_alphas: Optional[List[float]] = None, + modules_dim: Optional[Dict[str, int]] = None, + modules_alpha: Optional[Dict[str, int]] = None, + module_class: Type[object] = LoRAModule, + varbose: Optional[bool] = False, ) -> None: """ LoRA network: すごく引数が多いが、パターンは以下の通り @@ -783,8 +802,21 @@ def __init__( print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") # create module instances - def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]: - prefix = LoRANetwork.LORA_PREFIX_UNET if is_unet else LoRANetwork.LORA_PREFIX_TEXT_ENCODER + def create_modules( + is_unet: bool, + text_encoder_idx: Optional[int], # None, 1, 2 + root_module: torch.nn.Module, + target_replace_modules: List[torch.nn.Module], + ) -> List[LoRAModule]: + prefix = ( + self.LORA_PREFIX_UNET + if is_unet + else ( + self.LORA_PREFIX_TEXT_ENCODER + if text_encoder_idx is None + else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2) + ) + ) loras = [] skipped = [] for name, module in root_module.named_modules(): @@ -800,11 +832,14 @@ def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules dim = None alpha = None + if modules_dim is not None: + # モジュール指定あり if lora_name in modules_dim: dim = modules_dim[lora_name] alpha = modules_alpha[lora_name] elif is_unet and block_dims is not None: + # U-Netでblock_dims指定あり block_idx = get_block_index(lora_name) if is_linear or is_conv2d_1x1: dim = block_dims[block_idx] @@ -813,6 +848,7 @@ def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules dim = conv_block_dims[block_idx] alpha = conv_block_alphas[block_idx] else: + # 通常、すべて対象とする if is_linear or is_conv2d_1x1: dim = self.lora_dim alpha = self.alpha @@ -821,6 +857,7 @@ def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules alpha = self.conv_alpha if dim is None or dim == 0: + # skipした情報を出力 if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None or conv_block_dims is not None): skipped.append(lora_name) continue @@ -838,7 +875,23 @@ def create_modules(is_unet, root_module: torch.nn.Module, 
target_replace_modules loras.append(lora) return loras, skipped - self.text_encoder_loras, skipped_te = create_modules(False, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + text_encoders = text_encoder if type(text_encoder) == list else [text_encoder] + + # create LoRA for text encoder + # 毎回すべてのモジュールを作るのは無駄なので要検討 + self.text_encoder_loras = [] + skipped_te = [] + for i, text_encoder in enumerate(text_encoders): + if len(text_encoders) > 1: + index = i + 1 + print(f"create LoRA for Text Encoder {index}:") + else: + index = None + print(f"create LoRA for Text Encoder:") + + text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + self.text_encoder_loras.extend(text_encoder_loras) + skipped_te += skipped print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights @@ -846,7 +899,7 @@ def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None: target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 - self.unet_loras, skipped_un = create_modules(True, unet, target_modules) + self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules) print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") skipped = skipped_te + skipped_un @@ -961,6 +1014,7 @@ def get_lr_weight(self, lora: LoRAModule) -> float: return lr_weight + # 二つのText Encoderに別々の学習率を設定できるようにするといいかも def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): self.requires_grad_(True) all_params = [] diff --git a/networks/lora_diffusers.py b/networks/lora_diffusers.py new file mode 100644 index 000000000..47d75ac4d --- /dev/null +++ b/networks/lora_diffusers.py @@ -0,0 +1,609 @@ +# Diffusersで動くLoRA。このファイル単独で完結する。 +# LoRA module for Diffusers. This file works independently. + +import bisect +import math +import random +from typing import Any, Dict, List, Mapping, Optional, Union +from diffusers import UNet2DConditionModel +import numpy as np +from tqdm import tqdm +from transformers import CLIPTextModel +import torch + + +def make_unet_conversion_map() -> Dict[str, str]: + unet_conversion_map_layer = [] + + for i in range(3): # num_blocks is 3 in sdxl + # loop over downblocks/upblocks + for j in range(2): + # loop over resnets/attentions for downblocks + hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." + sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." + unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) + + if i < 3: + # no attention layers in down_blocks.3 + hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." + sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." + unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) + + for j in range(3): + # loop over resnets/attentions for upblocks + hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." + sd_up_res_prefix = f"output_blocks.{3*i + j}.0." + unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) + + # if i > 0: commentout for sdxl + # no attention layers in up_blocks.0 + hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." + sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." + unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) + + if i < 3: + # no downsample in down_blocks.3 + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." 
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." + unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) + + # no upsample in up_blocks.3 + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl + unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) + + hf_mid_atn_prefix = "mid_block.attentions.0." + sd_mid_atn_prefix = "middle_block.1." + unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) + + for j in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{j}." + sd_mid_res_prefix = f"middle_block.{2*j}." + unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) + + unet_conversion_map_resnet = [ + # (stable-diffusion, HF Diffusers) + ("in_layers.0.", "norm1."), + ("in_layers.2.", "conv1."), + ("out_layers.0.", "norm2."), + ("out_layers.3.", "conv2."), + ("emb_layers.1.", "time_emb_proj."), + ("skip_connection.", "conv_shortcut."), + ] + + unet_conversion_map = [] + for sd, hf in unet_conversion_map_layer: + if "resnets" in hf: + for sd_res, hf_res in unet_conversion_map_resnet: + unet_conversion_map.append((sd + sd_res, hf + hf_res)) + else: + unet_conversion_map.append((sd, hf)) + + for j in range(2): + hf_time_embed_prefix = f"time_embedding.linear_{j+1}." + sd_time_embed_prefix = f"time_embed.{j*2}." + unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix)) + + for j in range(2): + hf_label_embed_prefix = f"add_embedding.linear_{j+1}." + sd_label_embed_prefix = f"label_emb.0.{j*2}." + unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix)) + + unet_conversion_map.append(("input_blocks.0.0.", "conv_in.")) + unet_conversion_map.append(("out.0.", "conv_norm_out.")) + unet_conversion_map.append(("out.2.", "conv_out.")) + + sd_hf_conversion_map = {sd.replace(".", "_")[:-1]: hf.replace(".", "_")[:-1] for sd, hf in unet_conversion_map} + return sd_hf_conversion_map + + +UNET_CONVERSION_MAP = make_unet_conversion_map() + + +class LoRAModule(torch.nn.Module): + """ + replaces forward method of the original Linear, instead of replacing the original Linear module. 
+ """ + + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + ): + """if alpha == 0 or None, alpha is rank (no scaling).""" + super().__init__() + self.lora_name = lora_name + + if org_module.__class__.__name__ == "Conv2d" or org_module.__class__.__name__ == "LoRACompatibleConv": + in_dim = org_module.in_channels + out_dim = org_module.out_channels + else: + in_dim = org_module.in_features + out_dim = org_module.out_features + + self.lora_dim = lora_dim + + if org_module.__class__.__name__ == "Conv2d" or org_module.__class__.__name__ == "LoRACompatibleConv": + kernel_size = org_module.kernel_size + stride = org_module.stride + padding = org_module.padding + self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) + self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) + else: + self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) + self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + alpha = self.lora_dim if alpha is None or alpha == 0 else alpha + self.scale = alpha / self.lora_dim + self.register_buffer("alpha", torch.tensor(alpha)) # 勾配計算に含めない / not included in gradient calculation + + # same as microsoft's + torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + torch.nn.init.zeros_(self.lora_up.weight) + + self.multiplier = multiplier + self.org_module = [org_module] + self.enabled = True + self.network: LoRANetwork = None + self.org_forward = None + + # override org_module's forward method + def apply_to(self, multiplier=None): + if multiplier is not None: + self.multiplier = multiplier + if self.org_forward is None: + self.org_forward = self.org_module[0].forward + self.org_module[0].forward = self.forward + + # restore org_module's forward method + def unapply_to(self): + if self.org_forward is not None: + self.org_module[0].forward = self.org_forward + + # forward with lora + # scale is used LoRACompatibleConv, but we ignore it because we have multiplier + def forward(self, x, scale=1.0): + if not self.enabled: + return self.org_forward(x) + return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale + + def set_network(self, network): + self.network = network + + # merge lora weight to org weight + def merge_to(self, multiplier=1.0): + # get lora weight + lora_weight = self.get_weight(multiplier) + + # get org weight + org_sd = self.org_module[0].state_dict() + org_weight = org_sd["weight"] + weight = org_weight + lora_weight.to(org_weight.device, dtype=org_weight.dtype) + + # set weight to org_module + org_sd["weight"] = weight + self.org_module[0].load_state_dict(org_sd) + + # restore org weight from lora weight + def restore_from(self, multiplier=1.0): + # get lora weight + lora_weight = self.get_weight(multiplier) + + # get org weight + org_sd = self.org_module[0].state_dict() + org_weight = org_sd["weight"] + weight = org_weight - lora_weight.to(org_weight.device, dtype=org_weight.dtype) + + # set weight to org_module + org_sd["weight"] = weight + self.org_module[0].load_state_dict(org_sd) + + # return lora weight + def get_weight(self, multiplier=None): + if multiplier is None: + multiplier = self.multiplier + + # get up/down weight from module + up_weight = self.lora_up.weight.to(torch.float) + down_weight = self.lora_down.weight.to(torch.float) + + # 
pre-calculated weight + if len(down_weight.size()) == 2: + # linear + weight = self.multiplier * (up_weight @ down_weight) * self.scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + self.multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * self.scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + weight = self.multiplier * conved * self.scale + + return weight + + +# Create network from weights for inference, weights are not loaded here +def create_network_from_weights( + text_encoder: Union[CLIPTextModel, List[CLIPTextModel]], unet: UNet2DConditionModel, weights_sd: Dict, multiplier: float = 1.0 +): + # get dim/alpha mapping + modules_dim = {} + modules_alpha = {} + for key, value in weights_sd.items(): + if "." not in key: + continue + + lora_name = key.split(".")[0] + if "alpha" in key: + modules_alpha[lora_name] = value + elif "lora_down" in key: + dim = value.size()[0] + modules_dim[lora_name] = dim + # print(lora_name, value.size(), dim) + + # support old LoRA without alpha + for key in modules_dim.keys(): + if key not in modules_alpha: + modules_alpha[key] = modules_dim[key] + + return LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha) + + +def merge_lora_weights(pipe, weights_sd: Dict, multiplier: float = 1.0): + text_encoders = [pipe.text_encoder, pipe.text_encoder_2] if hasattr(pipe, "text_encoder_2") else [pipe.text_encoder] + unet = pipe.unet + + lora_network = create_network_from_weights(text_encoders, unet, weights_sd, multiplier=multiplier) + lora_network.load_state_dict(weights_sd) + lora_network.merge_to(multiplier=multiplier) + + +# block weightや学習に対応しない簡易版 / simple version without block weight and training +class LoRANetwork(torch.nn.Module): + UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"] + UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] + LORA_PREFIX_UNET = "lora_unet" + LORA_PREFIX_TEXT_ENCODER = "lora_te" + + # SDXL: must starts with LORA_PREFIX_TEXT_ENCODER + LORA_PREFIX_TEXT_ENCODER1 = "lora_te1" + LORA_PREFIX_TEXT_ENCODER2 = "lora_te2" + + def __init__( + self, + text_encoder: Union[List[CLIPTextModel], CLIPTextModel], + unet: UNet2DConditionModel, + multiplier: float = 1.0, + modules_dim: Optional[Dict[str, int]] = None, + modules_alpha: Optional[Dict[str, int]] = None, + varbose: Optional[bool] = False, + ) -> None: + super().__init__() + self.multiplier = multiplier + + print(f"create LoRA network from weights") + + # convert SDXL Stability AI's U-Net modules to Diffusers + converted = self.convert_unet_modules(modules_dim, modules_alpha) + if converted: + print(f"converted {converted} Stability AI's U-Net LoRA modules to Diffusers (SDXL)") + + # create module instances + def create_modules( + is_unet: bool, + text_encoder_idx: Optional[int], # None, 1, 2 + root_module: torch.nn.Module, + target_replace_modules: List[torch.nn.Module], + ) -> List[LoRAModule]: + prefix = ( + self.LORA_PREFIX_UNET + if is_unet + else ( + self.LORA_PREFIX_TEXT_ENCODER + if text_encoder_idx is None + else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2) + ) + ) + loras = [] + skipped = [] + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + 
for child_name, child_module in module.named_modules(): + is_linear = ( + child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "LoRACompatibleLinear" + ) + is_conv2d = ( + child_module.__class__.__name__ == "Conv2d" or child_module.__class__.__name__ == "LoRACompatibleConv" + ) + + if is_linear or is_conv2d: + lora_name = prefix + "." + name + "." + child_name + lora_name = lora_name.replace(".", "_") + + if lora_name not in modules_dim: + # print(f"skipped {lora_name} (not found in modules_dim)") + skipped.append(lora_name) + continue + + dim = modules_dim[lora_name] + alpha = modules_alpha[lora_name] + lora = LoRAModule( + lora_name, + child_module, + self.multiplier, + dim, + alpha, + ) + loras.append(lora) + return loras, skipped + + text_encoders = text_encoder if type(text_encoder) == list else [text_encoder] + + # create LoRA for text encoder + # 毎回すべてのモジュールを作るのは無駄なので要検討 / it is wasteful to create all modules every time, need to consider + self.text_encoder_loras: List[LoRAModule] = [] + skipped_te = [] + for i, text_encoder in enumerate(text_encoders): + if len(text_encoders) > 1: + index = i + 1 + else: + index = None + + text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + self.text_encoder_loras.extend(text_encoder_loras) + skipped_te += skipped + print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + if len(skipped_te) > 0: + print(f"skipped {len(skipped_te)} modules because of missing weight for text encoder.") + + # extend U-Net target modules to include Conv2d 3x3 + target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 + + self.unet_loras: List[LoRAModule] + self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules) + print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + if len(skipped_un) > 0: + print(f"skipped {len(skipped_un)} modules because of missing weight for U-Net.") + + # assertion + names = set() + for lora in self.text_encoder_loras + self.unet_loras: + names.add(lora.lora_name) + for lora_name in modules_dim.keys(): + assert lora_name in names, f"{lora_name} is not found in created LoRA modules." 
+ + # make to work load_state_dict + for lora in self.text_encoder_loras + self.unet_loras: + self.add_module(lora.lora_name, lora) + + # SDXL: convert SDXL Stability AI's U-Net modules to Diffusers + def convert_unet_modules(self, modules_dim, modules_alpha): + converted_count = 0 + not_converted_count = 0 + + map_keys = list(UNET_CONVERSION_MAP.keys()) + map_keys.sort() + + for key in list(modules_dim.keys()): + if key.startswith(LoRANetwork.LORA_PREFIX_UNET + "_"): + search_key = key.replace(LoRANetwork.LORA_PREFIX_UNET + "_", "") + position = bisect.bisect_right(map_keys, search_key) + map_key = map_keys[position - 1] + if search_key.startswith(map_key): + new_key = key.replace(map_key, UNET_CONVERSION_MAP[map_key]) + modules_dim[new_key] = modules_dim[key] + modules_alpha[new_key] = modules_alpha[key] + del modules_dim[key] + del modules_alpha[key] + converted_count += 1 + else: + not_converted_count += 1 + assert ( + converted_count == 0 or not_converted_count == 0 + ), f"some modules are not converted: {converted_count} converted, {not_converted_count} not converted" + return converted_count + + def set_multiplier(self, multiplier): + self.multiplier = multiplier + for lora in self.text_encoder_loras + self.unet_loras: + lora.multiplier = self.multiplier + + def apply_to(self, multiplier=1.0, apply_text_encoder=True, apply_unet=True): + if apply_text_encoder: + print("enable LoRA for text encoder") + for lora in self.text_encoder_loras: + lora.apply_to(multiplier) + if apply_unet: + print("enable LoRA for U-Net") + for lora in self.unet_loras: + lora.apply_to(multiplier) + + def unapply_to(self): + for lora in self.text_encoder_loras + self.unet_loras: + lora.unapply_to() + + def merge_to(self, multiplier=1.0): + print("merge LoRA weights to original weights") + for lora in tqdm(self.text_encoder_loras + self.unet_loras): + lora.merge_to(multiplier) + print(f"weights are merged") + + def restore_from(self, multiplier=1.0): + print("restore LoRA weights from original weights") + for lora in tqdm(self.text_encoder_loras + self.unet_loras): + lora.restore_from(multiplier) + print(f"weights are restored") + + def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): + # convert SDXL Stability AI's state dict to Diffusers' based state dict + map_keys = list(UNET_CONVERSION_MAP.keys()) # prefix of U-Net modules + map_keys.sort() + for key in list(state_dict.keys()): + if key.startswith(LoRANetwork.LORA_PREFIX_UNET + "_"): + search_key = key.replace(LoRANetwork.LORA_PREFIX_UNET + "_", "") + position = bisect.bisect_right(map_keys, search_key) + map_key = map_keys[position - 1] + if search_key.startswith(map_key): + new_key = key.replace(map_key, UNET_CONVERSION_MAP[map_key]) + state_dict[new_key] = state_dict[key] + del state_dict[key] + + # in case of V2, some weights have different shape, so we need to convert them + # because V2 LoRA is based on U-Net created by use_linear_projection=False + my_state_dict = self.state_dict() + for key in state_dict.keys(): + if state_dict[key].size() != my_state_dict[key].size(): + # print(f"convert {key} from {state_dict[key].size()} to {my_state_dict[key].size()}") + state_dict[key] = state_dict[key].view(my_state_dict[key].size()) + + return super().load_state_dict(state_dict, strict) + + +if __name__ == "__main__": + # sample code to use LoRANetwork + import os + import argparse + from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline + import torch + + device = torch.device("cuda" if torch.cuda.is_available() 
else "cpu") + + parser = argparse.ArgumentParser() + parser.add_argument("--model_id", type=str, default=None, help="model id for huggingface") + parser.add_argument("--lora_weights", type=str, default=None, help="path to LoRA weights") + parser.add_argument("--sdxl", action="store_true", help="use SDXL model") + parser.add_argument("--prompt", type=str, default="A photo of cat", help="prompt text") + parser.add_argument("--negative_prompt", type=str, default="", help="negative prompt text") + parser.add_argument("--seed", type=int, default=0, help="random seed") + args = parser.parse_args() + + image_prefix = args.model_id.replace("/", "_") + "_" + + # load Diffusers model + print(f"load model from {args.model_id}") + pipe: Union[StableDiffusionPipeline, StableDiffusionXLPipeline] + if args.sdxl: + # use_safetensors=True does not work with 0.18.2 + pipe = StableDiffusionXLPipeline.from_pretrained(args.model_id, variant="fp16", torch_dtype=torch.float16) + else: + pipe = StableDiffusionPipeline.from_pretrained(args.model_id, variant="fp16", torch_dtype=torch.float16) + pipe.to(device) + pipe.set_use_memory_efficient_attention_xformers(True) + + text_encoders = [pipe.text_encoder, pipe.text_encoder_2] if args.sdxl else [pipe.text_encoder] + + # load LoRA weights + print(f"load LoRA weights from {args.lora_weights}") + if os.path.splitext(args.lora_weights)[1] == ".safetensors": + from safetensors.torch import load_file + + lora_sd = load_file(args.lora_weights) + else: + lora_sd = torch.load(args.lora_weights) + + # create by LoRA weights and load weights + print(f"create LoRA network") + lora_network: LoRANetwork = create_network_from_weights(text_encoders, pipe.unet, lora_sd, multiplier=1.0) + + print(f"load LoRA network weights") + lora_network.load_state_dict(lora_sd) + + lora_network.to(device, dtype=pipe.unet.dtype) # required to apply_to. 
merge_to works without this + + # 必要があれば、元のモデルの重みをバックアップしておく + # back-up unet/text encoder weights if necessary + def detach_and_move_to_cpu(state_dict): + for k, v in state_dict.items(): + state_dict[k] = v.detach().cpu() + return state_dict + + org_unet_sd = pipe.unet.state_dict() + detach_and_move_to_cpu(org_unet_sd) + + org_text_encoder_sd = pipe.text_encoder.state_dict() + detach_and_move_to_cpu(org_text_encoder_sd) + + if args.sdxl: + org_text_encoder_2_sd = pipe.text_encoder_2.state_dict() + detach_and_move_to_cpu(org_text_encoder_2_sd) + + def seed_everything(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + + # create image with original weights + print(f"create image with original weights") + seed_everything(args.seed) + image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] + image.save(image_prefix + "original.png") + + # apply LoRA network to the model: slower than merge_to, but can be reverted easily + print(f"apply LoRA network to the model") + lora_network.apply_to(multiplier=1.0) + + print(f"create image with applied LoRA") + seed_everything(args.seed) + image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] + image.save(image_prefix + "applied_lora.png") + + # unapply LoRA network to the model + print(f"unapply LoRA network to the model") + lora_network.unapply_to() + + print(f"create image with unapplied LoRA") + seed_everything(args.seed) + image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] + image.save(image_prefix + "unapplied_lora.png") + + # merge LoRA network to the model: faster than apply_to, but requires back-up of original weights (or unmerge_to) + print(f"merge LoRA network to the model") + lora_network.merge_to(multiplier=1.0) + + print(f"create image with LoRA") + seed_everything(args.seed) + image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] + image.save(image_prefix + "merged_lora.png") + + # restore (unmerge) LoRA weights: numerically unstable + # マージされた重みを元に戻す。計算誤差のため、元の重みと完全に一致しないことがあるかもしれない + # 保存したstate_dictから元の重みを復元するのが確実 + print(f"restore (unmerge) LoRA weights") + lora_network.restore_from(multiplier=1.0) + + print(f"create image without LoRA") + seed_everything(args.seed) + image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] + image.save(image_prefix + "unmerged_lora.png") + + # restore original weights + print(f"restore original weights") + pipe.unet.load_state_dict(org_unet_sd) + pipe.text_encoder.load_state_dict(org_text_encoder_sd) + if args.sdxl: + pipe.text_encoder_2.load_state_dict(org_text_encoder_2_sd) + + print(f"create image with restored original weights") + seed_everything(args.seed) + image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] + image.save(image_prefix + "restore_original.png") + + # use convenience function to merge LoRA weights + print(f"merge LoRA weights with convenience function") + merge_lora_weights(pipe, lora_sd, multiplier=1.0) + + print(f"create image with merged LoRA weights") + seed_everything(args.seed) + image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] + image.save(image_prefix + "convenience_merged_lora.png") diff --git a/networks/lora_fa.py b/networks/lora_fa.py new file mode 100644 index 000000000..a357d7f7f --- /dev/null +++ b/networks/lora_fa.py @@ -0,0 +1,1241 @@ +# LoRA network module +# reference: +# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py +# 
https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py + +# temporary implementation of LoRA-FA: https://arxiv.org/abs/2308.03303 +# need to be refactored and merged to lora.py + +import math +import os +from typing import Dict, List, Optional, Tuple, Type, Union +from diffusers import AutoencoderKL +from transformers import CLIPTextModel +import numpy as np +import torch +import re + + +RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") + + +class LoRAModule(torch.nn.Module): + """ + replaces forward method of the original Linear, instead of replacing the original Linear module. + """ + + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=None, + rank_dropout=None, + module_dropout=None, + ): + """if alpha == 0 or None, alpha is rank (no scaling).""" + super().__init__() + self.lora_name = lora_name + + if org_module.__class__.__name__ == "Conv2d": + in_dim = org_module.in_channels + out_dim = org_module.out_channels + else: + in_dim = org_module.in_features + out_dim = org_module.out_features + + # if limit_rank: + # self.lora_dim = min(lora_dim, in_dim, out_dim) + # if self.lora_dim != lora_dim: + # print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") + # else: + self.lora_dim = lora_dim + + if org_module.__class__.__name__ == "Conv2d": + kernel_size = org_module.kernel_size + stride = org_module.stride + padding = org_module.padding + self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) + self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) + else: + self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) + self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + alpha = self.lora_dim if alpha is None or alpha == 0 else alpha + self.scale = alpha / self.lora_dim + self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える + + # # same as microsoft's + # torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + + # according to the paper, initialize LoRA-A (down) as normal distribution + torch.nn.init.normal_(self.lora_down.weight, std=math.sqrt(2.0 / (in_dim + self.lora_dim))) + + torch.nn.init.zeros_(self.lora_up.weight) + + self.multiplier = multiplier + self.org_module = org_module # remove in applying + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + + def get_trainable_params(self): + params = self.named_parameters() + trainable_params = [] + for param in params: + if param[0] == "lora_up.weight": # up only + trainable_params.append(param[1]) + return trainable_params + + def requires_grad_(self, requires_grad: bool = True): + self.lora_up.requires_grad_(requires_grad) + self.lora_down.requires_grad_(False) + return self + + def apply_to(self): + self.org_forward = self.org_module.forward + self.org_module.forward = self.forward + del self.org_module + + def forward(self, x): + org_forwarded = self.org_forward(x) + + # module dropout + if self.module_dropout is not None and self.training: + if torch.rand(1) < self.module_dropout: + return org_forwarded + + lx = self.lora_down(x) + + # normal dropout + if self.dropout is not None and self.training: + lx = torch.nn.functional.dropout(lx, p=self.dropout) + + # rank dropout + if self.rank_dropout is not None and 
self.training: + mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout + if len(lx.size()) == 3: + mask = mask.unsqueeze(1) # for Text Encoder + elif len(lx.size()) == 4: + mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d + lx = lx * mask + + # scaling for rank dropout: treat as if the rank is changed + # maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる + scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability + else: + scale = self.scale + + lx = self.lora_up(lx) + + return org_forwarded + lx * self.multiplier * scale + + +class LoRAInfModule(LoRAModule): + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + **kwargs, + ): + # no dropout for inference + super().__init__(lora_name, org_module, multiplier, lora_dim, alpha) + + self.org_module_ref = [org_module] # 後から参照できるように + self.enabled = True + + # check regional or not by lora_name + self.text_encoder = False + if lora_name.startswith("lora_te_"): + self.regional = False + self.use_sub_prompt = True + self.text_encoder = True + elif "attn2_to_k" in lora_name or "attn2_to_v" in lora_name: + self.regional = False + self.use_sub_prompt = True + elif "time_emb" in lora_name: + self.regional = False + self.use_sub_prompt = False + else: + self.regional = True + self.use_sub_prompt = False + + self.network: LoRANetwork = None + + def set_network(self, network): + self.network = network + + # freezeしてマージする + def merge_to(self, sd, dtype, device): + # get up/down weight + up_weight = sd["lora_up.weight"].to(torch.float).to(device) + down_weight = sd["lora_down.weight"].to(torch.float).to(device) + + # extract weight from org_module + org_sd = self.org_module.state_dict() + weight = org_sd["weight"].to(torch.float) + + # merge weight + if len(weight.size()) == 2: + # linear + weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + self.multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * self.scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # print(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + self.multiplier * conved * self.scale + + # set weight to org_module + org_sd["weight"] = weight.to(dtype) + self.org_module.load_state_dict(org_sd) + + # 復元できるマージのため、このモジュールのweightを返す + def get_weight(self, multiplier=None): + if multiplier is None: + multiplier = self.multiplier + + # get up/down weight from module + up_weight = self.lora_up.weight.to(torch.float) + down_weight = self.lora_down.weight.to(torch.float) + + # pre-calculated weight + if len(down_weight.size()) == 2: + # linear + weight = self.multiplier * (up_weight @ down_weight) * self.scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + self.multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * self.scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + weight = self.multiplier * conved * self.scale + + return weight + + def set_region(self, region): + self.region = region + self.region_mask = None + + def default_forward(self, x): + # print("default_forward", self.lora_name, x.size()) + return self.org_forward(x) 
+ self.lora_up(self.lora_down(x)) * self.multiplier * self.scale + + def forward(self, x): + if not self.enabled: + return self.org_forward(x) + + if self.network is None or self.network.sub_prompt_index is None: + return self.default_forward(x) + if not self.regional and not self.use_sub_prompt: + return self.default_forward(x) + + if self.regional: + return self.regional_forward(x) + else: + return self.sub_prompt_forward(x) + + def get_mask_for_x(self, x): + # calculate size from shape of x + if len(x.size()) == 4: + h, w = x.size()[2:4] + area = h * w + else: + area = x.size()[1] + + mask = self.network.mask_dic[area] + if mask is None: + raise ValueError(f"mask is None for resolution {area}") + if len(x.size()) != 4: + mask = torch.reshape(mask, (1, -1, 1)) + return mask + + def regional_forward(self, x): + if "attn2_to_out" in self.lora_name: + return self.to_out_forward(x) + + if self.network.mask_dic is None: # sub_prompt_index >= 3 + return self.default_forward(x) + + # apply mask for LoRA result + lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale + mask = self.get_mask_for_x(lx) + # print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size()) + lx = lx * mask + + x = self.org_forward(x) + x = x + lx + + if "attn2_to_q" in self.lora_name and self.network.is_last_network: + x = self.postp_to_q(x) + + return x + + def postp_to_q(self, x): + # repeat x to num_sub_prompts + has_real_uncond = x.size()[0] // self.network.batch_size == 3 + qc = self.network.batch_size # uncond + qc += self.network.batch_size * self.network.num_sub_prompts # cond + if has_real_uncond: + qc += self.network.batch_size # real_uncond + + query = torch.zeros((qc, x.size()[1], x.size()[2]), device=x.device, dtype=x.dtype) + query[: self.network.batch_size] = x[: self.network.batch_size] + + for i in range(self.network.batch_size): + qi = self.network.batch_size + i * self.network.num_sub_prompts + query[qi : qi + self.network.num_sub_prompts] = x[self.network.batch_size + i] + + if has_real_uncond: + query[-self.network.batch_size :] = x[-self.network.batch_size :] + + # print("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts) + return query + + def sub_prompt_forward(self, x): + if x.size()[0] == self.network.batch_size: # if uncond in text_encoder, do not apply LoRA + return self.org_forward(x) + + emb_idx = self.network.sub_prompt_index + if not self.text_encoder: + emb_idx += self.network.batch_size + + # apply sub prompt of X + lx = x[emb_idx :: self.network.num_sub_prompts] + lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale + + # print("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx) + + x = self.org_forward(x) + x[emb_idx :: self.network.num_sub_prompts] += lx + + return x + + def to_out_forward(self, x): + # print("to_out_forward", self.lora_name, x.size(), self.network.is_last_network) + + if self.network.is_last_network: + masks = [None] * self.network.num_sub_prompts + self.network.shared[self.lora_name] = (None, masks) + else: + lx, masks = self.network.shared[self.lora_name] + + # call own LoRA + x1 = x[self.network.batch_size + self.network.sub_prompt_index :: self.network.num_sub_prompts] + lx1 = self.lora_up(self.lora_down(x1)) * self.multiplier * self.scale + + if self.network.is_last_network: + lx = torch.zeros( + (self.network.num_sub_prompts * self.network.batch_size, *lx1.size()[1:]), device=lx1.device, dtype=lx1.dtype + ) + self.network.shared[self.lora_name] = (lx, masks) 
+ + # print("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts) + lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1 + masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1) + + # if not last network, return x and masks + x = self.org_forward(x) + if not self.network.is_last_network: + return x + + lx, masks = self.network.shared.pop(self.lora_name) + + # if last network, combine separated x with mask weighted sum + has_real_uncond = x.size()[0] // self.network.batch_size == self.network.num_sub_prompts + 2 + + out = torch.zeros((self.network.batch_size * (3 if has_real_uncond else 2), *x.size()[1:]), device=x.device, dtype=x.dtype) + out[: self.network.batch_size] = x[: self.network.batch_size] # uncond + if has_real_uncond: + out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond + + # print("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts) + # for i in range(len(masks)): + # if masks[i] is None: + # masks[i] = torch.zeros_like(masks[-1]) + + mask = torch.cat(masks) + mask_sum = torch.sum(mask, dim=0) + 1e-4 + for i in range(self.network.batch_size): + # 1枚の画像ごとに処理する + lx1 = lx[i * self.network.num_sub_prompts : (i + 1) * self.network.num_sub_prompts] + lx1 = lx1 * mask + lx1 = torch.sum(lx1, dim=0) + + xi = self.network.batch_size + i * self.network.num_sub_prompts + x1 = x[xi : xi + self.network.num_sub_prompts] + x1 = x1 * mask + x1 = torch.sum(x1, dim=0) + x1 = x1 / mask_sum + + x1 = x1 + lx1 + out[self.network.batch_size + i] = x1 + + # print("to_out_forward", x.size(), out.size(), has_real_uncond) + return out + + +def parse_block_lr_kwargs(nw_kwargs): + down_lr_weight = nw_kwargs.get("down_lr_weight", None) + mid_lr_weight = nw_kwargs.get("mid_lr_weight", None) + up_lr_weight = nw_kwargs.get("up_lr_weight", None) + + # 以上のいずれにも設定がない場合は無効としてNoneを返す + if down_lr_weight is None and mid_lr_weight is None and up_lr_weight is None: + return None, None, None + + # extract learning rate weight for each block + if down_lr_weight is not None: + # if some parameters are not set, use zero + if "," in down_lr_weight: + down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")] + + if mid_lr_weight is not None: + mid_lr_weight = float(mid_lr_weight) + + if up_lr_weight is not None: + if "," in up_lr_weight: + up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")] + + down_lr_weight, mid_lr_weight, up_lr_weight = get_block_lr_weight( + down_lr_weight, mid_lr_weight, up_lr_weight, float(nw_kwargs.get("block_lr_zero_threshold", 0.0)) + ) + + return down_lr_weight, mid_lr_weight, up_lr_weight + + +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + vae: AutoencoderKL, + text_encoder: Union[CLIPTextModel, List[CLIPTextModel]], + unet, + neuron_dropout: Optional[float] = None, + **kwargs, +): + if network_dim is None: + network_dim = 4 # default + if network_alpha is None: + network_alpha = 1.0 + + # extract dim/alpha for conv2d, and block dim + conv_dim = kwargs.get("conv_dim", None) + conv_alpha = kwargs.get("conv_alpha", None) + if conv_dim is not None: + conv_dim = int(conv_dim) + if conv_alpha is None: + conv_alpha = 1.0 + else: + conv_alpha = float(conv_alpha) + + # block dim/alpha/lr + block_dims = kwargs.get("block_dims", None) + down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs) + + # 以上のいずれかに指定があればblockごとのdim(rank)を有効にする + if 
block_dims is not None or down_lr_weight is not None or mid_lr_weight is not None or up_lr_weight is not None: + block_alphas = kwargs.get("block_alphas", None) + conv_block_dims = kwargs.get("conv_block_dims", None) + conv_block_alphas = kwargs.get("conv_block_alphas", None) + + block_dims, block_alphas, conv_block_dims, conv_block_alphas = get_block_dims_and_alphas( + block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha + ) + + # remove block dim/alpha without learning rate + block_dims, block_alphas, conv_block_dims, conv_block_alphas = remove_block_dims_and_alphas( + block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight + ) + + else: + block_alphas = None + conv_block_dims = None + conv_block_alphas = None + + # rank/module dropout + rank_dropout = kwargs.get("rank_dropout", None) + if rank_dropout is not None: + rank_dropout = float(rank_dropout) + module_dropout = kwargs.get("module_dropout", None) + if module_dropout is not None: + module_dropout = float(module_dropout) + + # すごく引数が多いな ( ^ω^)・・・ + network = LoRANetwork( + text_encoder, + unet, + multiplier=multiplier, + lora_dim=network_dim, + alpha=network_alpha, + dropout=neuron_dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + conv_lora_dim=conv_dim, + conv_alpha=conv_alpha, + block_dims=block_dims, + block_alphas=block_alphas, + conv_block_dims=conv_block_dims, + conv_block_alphas=conv_block_alphas, + varbose=True, + ) + + if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None: + network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight) + + return network + + +# このメソッドは外部から呼び出される可能性を考慮しておく +# network_dim, network_alpha にはデフォルト値が入っている。 +# block_dims, block_alphas は両方ともNoneまたは両方とも値が入っている +# conv_dim, conv_alpha は両方ともNoneまたは両方とも値が入っている +def get_block_dims_and_alphas( + block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha +): + num_total_blocks = LoRANetwork.NUM_OF_BLOCKS * 2 + 1 + + def parse_ints(s): + return [int(i) for i in s.split(",")] + + def parse_floats(s): + return [float(i) for i in s.split(",")] + + # block_dimsとblock_alphasをパースする。必ず値が入る + if block_dims is not None: + block_dims = parse_ints(block_dims) + assert ( + len(block_dims) == num_total_blocks + ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください" + else: + print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります") + block_dims = [network_dim] * num_total_blocks + + if block_alphas is not None: + block_alphas = parse_floats(block_alphas) + assert ( + len(block_alphas) == num_total_blocks + ), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください" + else: + print( + f"block_alphas is not specified. 
all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります" + ) + block_alphas = [network_alpha] * num_total_blocks + + # conv_block_dimsとconv_block_alphasを、指定がある場合のみパースする。指定がなければconv_dimとconv_alphaを使う + if conv_block_dims is not None: + conv_block_dims = parse_ints(conv_block_dims) + assert ( + len(conv_block_dims) == num_total_blocks + ), f"conv_block_dims must have {num_total_blocks} elements / conv_block_dimsは{num_total_blocks}個指定してください" + + if conv_block_alphas is not None: + conv_block_alphas = parse_floats(conv_block_alphas) + assert ( + len(conv_block_alphas) == num_total_blocks + ), f"conv_block_alphas must have {num_total_blocks} elements / conv_block_alphasは{num_total_blocks}個指定してください" + else: + if conv_alpha is None: + conv_alpha = 1.0 + print( + f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります" + ) + conv_block_alphas = [conv_alpha] * num_total_blocks + else: + if conv_dim is not None: + print( + f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります" + ) + conv_block_dims = [conv_dim] * num_total_blocks + conv_block_alphas = [conv_alpha] * num_total_blocks + else: + conv_block_dims = None + conv_block_alphas = None + + return block_dims, block_alphas, conv_block_dims, conv_block_alphas + + +# 層別学習率用に層ごとの学習率に対する倍率を定義する、外部から呼び出される可能性を考慮しておく +def get_block_lr_weight( + down_lr_weight, mid_lr_weight, up_lr_weight, zero_threshold +) -> Tuple[List[float], List[float], List[float]]: + # パラメータ未指定時は何もせず、今までと同じ動作とする + if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None: + return None, None, None + + max_len = LoRANetwork.NUM_OF_BLOCKS # フルモデル相当でのup,downの層の数 + + def get_list(name_with_suffix) -> List[float]: + import math + + tokens = name_with_suffix.split("+") + name = tokens[0] + base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0 + + if name == "cosine": + return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in reversed(range(max_len))] + elif name == "sine": + return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in range(max_len)] + elif name == "linear": + return [i / (max_len - 1) + base_lr for i in range(max_len)] + elif name == "reverse_linear": + return [i / (max_len - 1) + base_lr for i in reversed(range(max_len))] + elif name == "zeros": + return [0.0 + base_lr] * max_len + else: + print( + "Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros" + % (name) + ) + return None + + if type(down_lr_weight) == str: + down_lr_weight = get_list(down_lr_weight) + if type(up_lr_weight) == str: + up_lr_weight = get_list(up_lr_weight) + + if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len): + print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len) + print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len) + up_lr_weight = up_lr_weight[:max_len] + down_lr_weight = down_lr_weight[:max_len] + + if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len): + print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." 
% max_len) + print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len) + + if down_lr_weight != None and len(down_lr_weight) < max_len: + down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight)) + if up_lr_weight != None and len(up_lr_weight) < max_len: + up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight)) + + if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None): + print("apply block learning rate / 階層別学習率を適用します。") + if down_lr_weight != None: + down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight] + print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", down_lr_weight) + else: + print("down_lr_weight: all 1.0, すべて1.0") + + if mid_lr_weight != None: + mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0 + print("mid_lr_weight:", mid_lr_weight) + else: + print("mid_lr_weight: 1.0") + + if up_lr_weight != None: + up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight] + print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", up_lr_weight) + else: + print("up_lr_weight: all 1.0, すべて1.0") + + return down_lr_weight, mid_lr_weight, up_lr_weight + + +# lr_weightが0のblockをblock_dimsから除外する、外部から呼び出す可能性を考慮しておく +def remove_block_dims_and_alphas( + block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight +): + # set 0 to block dim without learning rate to remove the block + if down_lr_weight != None: + for i, lr in enumerate(down_lr_weight): + if lr == 0: + block_dims[i] = 0 + if conv_block_dims is not None: + conv_block_dims[i] = 0 + if mid_lr_weight != None: + if mid_lr_weight == 0: + block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0 + if conv_block_dims is not None: + conv_block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0 + if up_lr_weight != None: + for i, lr in enumerate(up_lr_weight): + if lr == 0: + block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0 + if conv_block_dims is not None: + conv_block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0 + + return block_dims, block_alphas, conv_block_dims, conv_block_alphas + + +# 外部から呼び出す可能性を考慮しておく +def get_block_index(lora_name: str) -> int: + block_idx = -1 # invalid lora name + + m = RE_UPDOWN.search(lora_name) + if m: + g = m.groups() + i = int(g[1]) + j = int(g[3]) + if g[2] == "resnets": + idx = 3 * i + j + elif g[2] == "attentions": + idx = 3 * i + j + elif g[2] == "upsamplers" or g[2] == "downsamplers": + idx = 3 * i + 2 + + if g[0] == "down": + block_idx = 1 + idx # 0に該当するLoRAは存在しない + elif g[0] == "up": + block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx + + elif "mid_block_" in lora_name: + block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12 + + return block_idx + + +# Create network from weights for inference, weights are not loaded here (because can be merged) +def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs): + if weights_sd is None: + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file, safe_open + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + # get dim/alpha mapping + modules_dim = {} + modules_alpha = {} + for key, value in weights_sd.items(): + if "." 
not in key: + continue + + lora_name = key.split(".")[0] + if "alpha" in key: + modules_alpha[lora_name] = value + elif "lora_down" in key: + dim = value.size()[0] + modules_dim[lora_name] = dim + # print(lora_name, value.size(), dim) + + # support old LoRA without alpha + for key in modules_dim.keys(): + if key not in modules_alpha: + modules_alpha[key] = modules_dim[key] + + module_class = LoRAInfModule if for_inference else LoRAModule + + network = LoRANetwork( + text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class + ) + + # block lr + down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs) + if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None: + network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight) + + return network, weights_sd + + +class LoRANetwork(torch.nn.Module): + NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数 + + UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"] + UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] + LORA_PREFIX_UNET = "lora_unet" + LORA_PREFIX_TEXT_ENCODER = "lora_te" + + # SDXL: must starts with LORA_PREFIX_TEXT_ENCODER + LORA_PREFIX_TEXT_ENCODER1 = "lora_te1" + LORA_PREFIX_TEXT_ENCODER2 = "lora_te2" + + def __init__( + self, + text_encoder: Union[List[CLIPTextModel], CLIPTextModel], + unet, + multiplier: float = 1.0, + lora_dim: int = 4, + alpha: float = 1, + dropout: Optional[float] = None, + rank_dropout: Optional[float] = None, + module_dropout: Optional[float] = None, + conv_lora_dim: Optional[int] = None, + conv_alpha: Optional[float] = None, + block_dims: Optional[List[int]] = None, + block_alphas: Optional[List[float]] = None, + conv_block_dims: Optional[List[int]] = None, + conv_block_alphas: Optional[List[float]] = None, + modules_dim: Optional[Dict[str, int]] = None, + modules_alpha: Optional[Dict[str, int]] = None, + module_class: Type[object] = LoRAModule, + varbose: Optional[bool] = False, + ) -> None: + """ + LoRA network: すごく引数が多いが、パターンは以下の通り + 1. lora_dimとalphaを指定 + 2. lora_dim、alpha、conv_lora_dim、conv_alphaを指定 + 3. block_dimsとblock_alphasを指定 : Conv2d3x3には適用しない + 4. block_dims、block_alphas、conv_block_dims、conv_block_alphasを指定 : Conv2d3x3にも適用する + 5. modules_dimとmodules_alphaを指定 (推論用) + """ + super().__init__() + self.multiplier = multiplier + + self.lora_dim = lora_dim + self.alpha = alpha + self.conv_lora_dim = conv_lora_dim + self.conv_alpha = conv_alpha + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + + if modules_dim is not None: + print(f"create LoRA network from weights") + elif block_dims is not None: + print(f"create LoRA network from block_dims") + print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") + print(f"block_dims: {block_dims}") + print(f"block_alphas: {block_alphas}") + if conv_block_dims is not None: + print(f"conv_block_dims: {conv_block_dims}") + print(f"conv_block_alphas: {conv_block_alphas}") + else: + print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") + print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") + if self.conv_lora_dim is not None: + print(f"apply LoRA to Conv2d with kernel size (3,3). 
dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") + + # create module instances + def create_modules( + is_unet: bool, + text_encoder_idx: Optional[int], # None, 1, 2 + root_module: torch.nn.Module, + target_replace_modules: List[torch.nn.Module], + ) -> List[LoRAModule]: + prefix = ( + self.LORA_PREFIX_UNET + if is_unet + else ( + self.LORA_PREFIX_TEXT_ENCODER + if text_encoder_idx is None + else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2) + ) + ) + loras = [] + skipped = [] + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + is_linear = child_module.__class__.__name__ == "Linear" + is_conv2d = child_module.__class__.__name__ == "Conv2d" + is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) + + if is_linear or is_conv2d: + lora_name = prefix + "." + name + "." + child_name + lora_name = lora_name.replace(".", "_") + + dim = None + alpha = None + + if modules_dim is not None: + # モジュール指定あり + if lora_name in modules_dim: + dim = modules_dim[lora_name] + alpha = modules_alpha[lora_name] + elif is_unet and block_dims is not None: + # U-Netでblock_dims指定あり + block_idx = get_block_index(lora_name) + if is_linear or is_conv2d_1x1: + dim = block_dims[block_idx] + alpha = block_alphas[block_idx] + elif conv_block_dims is not None: + dim = conv_block_dims[block_idx] + alpha = conv_block_alphas[block_idx] + else: + # 通常、すべて対象とする + if is_linear or is_conv2d_1x1: + dim = self.lora_dim + alpha = self.alpha + elif self.conv_lora_dim is not None: + dim = self.conv_lora_dim + alpha = self.conv_alpha + + if dim is None or dim == 0: + # skipした情報を出力 + if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None or conv_block_dims is not None): + skipped.append(lora_name) + continue + + lora = module_class( + lora_name, + child_module, + self.multiplier, + dim, + alpha, + dropout=dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + ) + loras.append(lora) + return loras, skipped + + text_encoders = text_encoder if type(text_encoder) == list else [text_encoder] + + # create LoRA for text encoder + # 毎回すべてのモジュールを作るのは無駄なので要検討 + self.text_encoder_loras = [] + skipped_te = [] + for i, text_encoder in enumerate(text_encoders): + if len(text_encoders) > 1: + index = i + 1 + print(f"create LoRA for Text Encoder {index}:") + else: + index = None + print(f"create LoRA for Text Encoder:") + + text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + self.text_encoder_loras.extend(text_encoder_loras) + skipped_te += skipped + print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + + # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights + target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None: + target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 + + self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules) + print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + + skipped = skipped_te + skipped_un + if varbose and len(skipped) > 0: + print( + f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" + ) + for name in skipped: + print(f"\t{name}") + + 
self.up_lr_weight: List[float] = None + self.down_lr_weight: List[float] = None + self.mid_lr_weight: float = None + self.block_lr = False + + # assertion + names = set() + for lora in self.text_encoder_loras + self.unet_loras: + assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" + names.add(lora.lora_name) + + def set_multiplier(self, multiplier): + self.multiplier = multiplier + for lora in self.text_encoder_loras + self.unet_loras: + lora.multiplier = self.multiplier + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + info = self.load_state_dict(weights_sd, False) + return info + + def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): + if apply_text_encoder: + print("enable LoRA for text encoder") + else: + self.text_encoder_loras = [] + + if apply_unet: + print("enable LoRA for U-Net") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + lora.apply_to() + self.add_module(lora.lora_name, lora) + + # マージできるかどうかを返す + def is_mergeable(self): + return True + + # TODO refactor to common function with apply_to + def merge_to(self, text_encoder, unet, weights_sd, dtype, device): + apply_text_encoder = apply_unet = False + for key in weights_sd.keys(): + if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER): + apply_text_encoder = True + elif key.startswith(LoRANetwork.LORA_PREFIX_UNET): + apply_unet = True + + if apply_text_encoder: + print("enable LoRA for text encoder") + else: + self.text_encoder_loras = [] + + if apply_unet: + print("enable LoRA for U-Net") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + sd_for_lora = {} + for key in weights_sd.keys(): + if key.startswith(lora.lora_name): + sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] + lora.merge_to(sd_for_lora, dtype, device) + + print(f"weights are merged") + + # 層別学習率用に層ごとの学習率に対する倍率を定義する 引数の順番が逆だがとりあえず気にしない + def set_block_lr_weight( + self, + up_lr_weight: List[float] = None, + mid_lr_weight: float = None, + down_lr_weight: List[float] = None, + ): + self.block_lr = True + self.down_lr_weight = down_lr_weight + self.mid_lr_weight = mid_lr_weight + self.up_lr_weight = up_lr_weight + + def get_lr_weight(self, lora: LoRAModule) -> float: + lr_weight = 1.0 + block_idx = get_block_index(lora.lora_name) + if block_idx < 0: + return lr_weight + + if block_idx < LoRANetwork.NUM_OF_BLOCKS: + if self.down_lr_weight != None: + lr_weight = self.down_lr_weight[block_idx] + elif block_idx == LoRANetwork.NUM_OF_BLOCKS: + if self.mid_lr_weight != None: + lr_weight = self.mid_lr_weight + elif block_idx > LoRANetwork.NUM_OF_BLOCKS: + if self.up_lr_weight != None: + lr_weight = self.up_lr_weight[block_idx - LoRANetwork.NUM_OF_BLOCKS - 1] + + return lr_weight + + # 二つのText Encoderに別々の学習率を設定できるようにするといいかも + def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): + self.requires_grad_(True) + all_params = [] + + def enumerate_params(loras: List[LoRAModule]): + params = [] + for lora in loras: + # params.extend(lora.parameters()) + params.extend(lora.get_trainable_params()) + return params + + if self.text_encoder_loras: + param_data = {"params": enumerate_params(self.text_encoder_loras)} + if text_encoder_lr is not None: + param_data["lr"] = text_encoder_lr + all_params.append(param_data) + + if self.unet_loras: + if 
self.block_lr: + # 学習率のグラフをblockごとにしたいので、blockごとにloraを分類 + block_idx_to_lora = {} + for lora in self.unet_loras: + idx = get_block_index(lora.lora_name) + if idx not in block_idx_to_lora: + block_idx_to_lora[idx] = [] + block_idx_to_lora[idx].append(lora) + + # blockごとにパラメータを設定する + for idx, block_loras in block_idx_to_lora.items(): + param_data = {"params": enumerate_params(block_loras)} + + if unet_lr is not None: + param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0]) + elif default_lr is not None: + param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0]) + if ("lr" in param_data) and (param_data["lr"] == 0): + continue + all_params.append(param_data) + + else: + param_data = {"params": enumerate_params(self.unet_loras)} + if unet_lr is not None: + param_data["lr"] = unet_lr + all_params.append(param_data) + + return all_params + + def enable_gradient_checkpointing(self): + # not supported + pass + + def prepare_grad_etc(self, text_encoder, unet): + self.requires_grad_(True) + + def on_epoch_start(self, text_encoder, unet): + self.train() + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + from library import train_util + + # Precalculate model hashes to save time on indexing + if metadata is None: + metadata = {} + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + + # mask is a tensor with values from 0 to 1 + def set_region(self, sub_prompt_index, is_last_network, mask): + if mask.max() == 0: + mask = torch.ones_like(mask) + + self.mask = mask + self.sub_prompt_index = sub_prompt_index + self.is_last_network = is_last_network + + for lora in self.text_encoder_loras + self.unet_loras: + lora.set_network(self) + + def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared): + self.batch_size = batch_size + self.num_sub_prompts = num_sub_prompts + self.current_size = (height, width) + self.shared = shared + + # create masks + mask = self.mask + mask_dic = {} + mask = mask.unsqueeze(0).unsqueeze(1) # b(1),c(1),h,w + ref_weight = self.text_encoder_loras[0].lora_down.weight if self.text_encoder_loras else self.unet_loras[0].lora_down.weight + dtype = ref_weight.dtype + device = ref_weight.device + + def resize_add(mh, mw): + # print(mh, mw, mh * mw) + m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16 + m = m.to(device, dtype=dtype) + mask_dic[mh * mw] = m + + h = height // 8 + w = width // 8 + for _ in range(4): + resize_add(h, w) + if h % 2 == 1 or w % 2 == 1: # add extra shape if h/w is not divisible by 2 + resize_add(h + h % 2, w + w % 2) + h = (h + 1) // 2 + w = (w + 1) // 2 + + self.mask_dic = mask_dic + + def backup_weights(self): + # 重みのバックアップを行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not hasattr(org_module, "_lora_org_weight"): + sd = org_module.state_dict() + org_module._lora_org_weight 
= sd["weight"].detach().clone() + org_module._lora_restored = True + + def restore_weights(self): + # 重みのリストアを行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not org_module._lora_restored: + sd = org_module.state_dict() + sd["weight"] = org_module._lora_org_weight + org_module.load_state_dict(sd) + org_module._lora_restored = True + + def pre_calculation(self): + # 事前計算を行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + sd = org_module.state_dict() + + org_weight = sd["weight"] + lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype) + sd["weight"] = org_weight + lora_weight + assert sd["weight"].shape == org_weight.shape + org_module.load_state_dict(sd) + + org_module._lora_restored = False + lora.enabled = False + + def apply_max_norm_regularization(self, max_norm_value, device): + downkeys = [] + upkeys = [] + alphakeys = [] + norms = [] + keys_scaled = 0 + + state_dict = self.state_dict() + for key in state_dict.keys(): + if "lora_down" in key and "weight" in key: + downkeys.append(key) + upkeys.append(key.replace("lora_down", "lora_up")) + alphakeys.append(key.replace("lora_down.weight", "alpha")) + + for i in range(len(downkeys)): + down = state_dict[downkeys[i]].to(device) + up = state_dict[upkeys[i]].to(device) + alpha = state_dict[alphakeys[i]].to(device) + dim = down.shape[0] + scale = alpha / dim + + if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): + updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) + elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3): + updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3) + else: + updown = up @ down + + updown *= scale + + norm = updown.norm().clamp(min=max_norm_value / 2) + desired = torch.clamp(norm, max=max_norm_value) + ratio = desired.cpu() / norm.cpu() + sqrt_ratio = ratio**0.5 + if ratio != 1: + keys_scaled += 1 + state_dict[upkeys[i]] *= sqrt_ratio + state_dict[downkeys[i]] *= sqrt_ratio + scalednorm = updown.norm() * ratio + norms.append(scalednorm.item()) + + return keys_scaled, sum(norms) / len(norms), max(norms) diff --git a/networks/merge_lora.py b/networks/merge_lora.py index 2fa8861bc..71492621e 100644 --- a/networks/merge_lora.py +++ b/networks/merge_lora.py @@ -1,8 +1,10 @@ import math import argparse import os +import time import torch from safetensors.torch import load_file, save_file +from library import sai_model_spec, train_util import library.model_util as model_util import lora @@ -10,22 +12,26 @@ def load_state_dict(file_name, dtype): if os.path.splitext(file_name)[1] == ".safetensors": sd = load_file(file_name) + metadata = train_util.load_metadata_from_safetensors(file_name) else: sd = torch.load(file_name, map_location="cpu") + metadata = {} + for key in list(sd.keys()): if type(sd[key]) == torch.Tensor: sd[key] = sd[key].to(dtype) - return sd + + return sd, metadata -def save_to_file(file_name, model, state_dict, dtype): +def save_to_file(file_name, model, state_dict, dtype, metadata): if dtype is not None: for key in list(state_dict.keys()): if type(state_dict[key]) == torch.Tensor: state_dict[key] = state_dict[key].to(dtype) if os.path.splitext(file_name)[1] == ".safetensors": - save_file(model, file_name) + save_file(model, file_name, metadata=metadata) else: torch.save(model, file_name) @@ -56,7 +62,7 @@ def 
merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): for model, ratio in zip(models, ratios): print(f"loading: {model}") - lora_sd = load_state_dict(model, merge_dtype) + lora_sd, _ = load_state_dict(model, merge_dtype) print(f"merging...") for key in lora_sd.keys(): @@ -81,9 +87,11 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): # W <- W + U * D weight = module.weight - # print(module_name, down_weight.size(), up_weight.size()) if len(weight.size()) == 2: # linear + if len(up_weight.size()) == 4: # use linear projection mismatch + up_weight = up_weight.squeeze(3).squeeze(2) + down_weight = down_weight.squeeze(3).squeeze(2) weight = weight + ratio * (up_weight @ down_weight) * scale elif down_weight.size()[2:4] == (1, 1): # conv2d 1x1 @@ -102,14 +110,22 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): module.weight = torch.nn.Parameter(weight) -def merge_lora_models(models, ratios, merge_dtype): +def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): base_alphas = {} # alpha for merged model base_dims = {} merged_sd = {} + v2 = None + base_model = None for model, ratio in zip(models, ratios): print(f"loading: {model}") - lora_sd = load_state_dict(model, merge_dtype) + lora_sd, lora_metadata = load_state_dict(model, merge_dtype) + + if lora_metadata is not None: + if v2 is None: + v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None) # return string + if base_model is None: + base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None) # get alpha and dim alphas = {} # alpha for current model @@ -142,6 +158,12 @@ def merge_lora_models(models, ratios, merge_dtype): for key in lora_sd.keys(): if "alpha" in key: continue + if "lora_up" in key and concat: + concat_dim = 1 + elif "lora_down" in key and concat: + concat_dim = 0 + else: + concat_dim = None lora_module_name = key[: key.rfind(".lora_")] @@ -149,12 +171,16 @@ def merge_lora_models(models, ratios, merge_dtype): alpha = alphas[lora_module_name] scale = math.sqrt(alpha / base_alpha) * ratio + scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。 if key in merged_sd: assert ( - merged_sd[key].size() == lora_sd[key].size() + merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None ), f"weights shape mismatch merging v1 and v2, different dims? 
/ 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません" - merged_sd[key] = merged_sd[key] + lora_sd[key] * scale + if concat_dim is not None: + merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim) + else: + merged_sd[key] = merged_sd[key] + lora_sd[key] * scale else: merged_sd[key] = lora_sd[key] * scale @@ -162,11 +188,37 @@ def merge_lora_models(models, ratios, merge_dtype): for lora_module_name, alpha in base_alphas.items(): key = lora_module_name + ".alpha" merged_sd[key] = torch.tensor(alpha) + if shuffle: + key_down = lora_module_name + ".lora_down.weight" + key_up = lora_module_name + ".lora_up.weight" + dim = merged_sd[key_down].shape[0] + perm = torch.randperm(dim) + merged_sd[key_down] = merged_sd[key_down][perm] + merged_sd[key_up] = merged_sd[key_up][:,perm] print("merged model") print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") - return merged_sd + # check all dims are same + dims_list = list(set(base_dims.values())) + alphas_list = list(set(base_alphas.values())) + all_same_dims = True + all_same_alphas = True + for dims in dims_list: + if dims != dims_list[0]: + all_same_dims = False + break + for alphas in alphas_list: + if alphas != alphas_list[0]: + all_same_alphas = False + break + + # build minimum metadata + dims = f"{dims_list[0]}" if all_same_dims else "Dynamic" + alphas = f"{alphas_list[0]}" if all_same_alphas else "Dynamic" + metadata = train_util.build_minimum_network_metadata(v2, base_model, "networks.lora", dims, alphas, None) + + return merged_sd, metadata, v2 == "True" def merge(args): @@ -193,13 +245,57 @@ def str_to_dtype(p): merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype) + if args.no_metadata: + sai_metadata = None + else: + merged_from = sai_model_spec.build_merged_from([args.sd_model] + args.models) + title = os.path.splitext(os.path.basename(args.save_to))[0] + sai_metadata = sai_model_spec.build_metadata( + None, + args.v2, + args.v2, + False, + False, + False, + time.time(), + title=title, + merged_from=merged_from, + is_stable_diffusion_ckpt=True, + ) + if args.v2: + # TODO read sai modelspec + print( + "Cannot determine if model is for v-prediction, so save metadata as v-prediction / modelがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します" + ) + print(f"saving SD model to: {args.save_to}") - model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, save_dtype, vae) + model_util.save_stable_diffusion_checkpoint( + args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, sai_metadata, save_dtype, vae + ) else: - state_dict = merge_lora_models(args.models, args.ratios, merge_dtype) + state_dict, metadata, v2 = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) + + print(f"calculating hashes and creating metadata...") + + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + if not args.no_metadata: + merged_from = sai_model_spec.build_merged_from(args.models) + title = os.path.splitext(os.path.basename(args.save_to))[0] + sai_metadata = sai_model_spec.build_metadata( + state_dict, v2, v2, False, True, False, time.time(), title=title, merged_from=merged_from + ) + if v2: + # TODO read sai modelspec + print( + "Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / 
LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します" + ) + metadata.update(sai_metadata) print(f"saving model to: {args.save_to}") - save_to_file(args.save_to, state_dict, state_dict, save_dtype) + save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata) def setup_parser() -> argparse.ArgumentParser: @@ -232,7 +328,25 @@ def setup_parser() -> argparse.ArgumentParser: "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors" ) parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") - + parser.add_argument( + "--no_metadata", + action="store_true", + help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / " + + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)", + ) + parser.add_argument( + "--concat", + action="store_true", + help="concat lora instead of merge (The dim(rank) of the output LoRA is the sum of the input dims) / " + + "マージの代わりに結合する(LoRAのdim(rank)は入力dimの合計になる)", + ) + parser.add_argument( + "--shuffle", + action="store_true", + help="shuffle lora weight./ " + + "LoRAの重みをシャッフルする", + ) + return parser diff --git a/networks/sdxl_merge_lora.py b/networks/sdxl_merge_lora.py new file mode 100644 index 000000000..c513eb59f --- /dev/null +++ b/networks/sdxl_merge_lora.py @@ -0,0 +1,348 @@ +import math +import argparse +import os +import time +import torch +from safetensors.torch import load_file, save_file +from tqdm import tqdm +from library import sai_model_spec, sdxl_model_util, train_util +import library.model_util as model_util +import lora + + +def load_state_dict(file_name, dtype): + if os.path.splitext(file_name)[1] == ".safetensors": + sd = load_file(file_name) + metadata = train_util.load_metadata_from_safetensors(file_name) + else: + sd = torch.load(file_name, map_location="cpu") + metadata = {} + + for key in list(sd.keys()): + if type(sd[key]) == torch.Tensor: + sd[key] = sd[key].to(dtype) + + return sd, metadata + + +def save_to_file(file_name, model, state_dict, dtype, metadata): + if dtype is not None: + for key in list(state_dict.keys()): + if type(state_dict[key]) == torch.Tensor: + state_dict[key] = state_dict[key].to(dtype) + + if os.path.splitext(file_name)[1] == ".safetensors": + save_file(model, file_name, metadata=metadata) + else: + torch.save(model, file_name) + + +def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype): + text_encoder1.to(merge_dtype) + text_encoder1.to(merge_dtype) + unet.to(merge_dtype) + + # create module map + name_to_module = {} + for i, root_module in enumerate([text_encoder1, text_encoder2, unet]): + if i <= 1: + if i == 0: + prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1 + else: + prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER2 + target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE + else: + prefix = lora.LoRANetwork.LORA_PREFIX_UNET + target_replace_modules = ( + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 + ) + + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d": + lora_name = prefix + "." + name + "." 
+ child_name + lora_name = lora_name.replace(".", "_") + name_to_module[lora_name] = child_module + + for model, ratio in zip(models, ratios): + print(f"loading: {model}") + lora_sd, _ = load_state_dict(model, merge_dtype) + + print(f"merging...") + for key in tqdm(lora_sd.keys()): + if "lora_down" in key: + up_key = key.replace("lora_down", "lora_up") + alpha_key = key[: key.index("lora_down")] + "alpha" + + # find original module for this lora + module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight" + if module_name not in name_to_module: + print(f"no module found for LoRA weight: {key}") + continue + module = name_to_module[module_name] + # print(f"apply {key} to {module}") + + down_weight = lora_sd[key] + up_weight = lora_sd[up_key] + + dim = down_weight.size()[0] + alpha = lora_sd.get(alpha_key, dim) + scale = alpha / dim + + # W <- W + U * D + weight = module.weight + # print(module_name, down_weight.size(), up_weight.size()) + if len(weight.size()) == 2: + # linear + weight = weight + ratio * (up_weight @ down_weight) * scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + ratio + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # print(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + ratio * conved * scale + + module.weight = torch.nn.Parameter(weight) + + +def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): + base_alphas = {} # alpha for merged model + base_dims = {} + + merged_sd = {} + v2 = None + base_model = None + for model, ratio in zip(models, ratios): + print(f"loading: {model}") + lora_sd, lora_metadata = load_state_dict(model, merge_dtype) + + if lora_metadata is not None: + if v2 is None: + v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None) # returns string, SDXLはv2がないのでFalseのはず + if base_model is None: + base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None) + + # get alpha and dim + alphas = {} # alpha for current model + dims = {} # dims for current model + for key in lora_sd.keys(): + if "alpha" in key: + lora_module_name = key[: key.rfind(".alpha")] + alpha = float(lora_sd[key].detach().numpy()) + alphas[lora_module_name] = alpha + if lora_module_name not in base_alphas: + base_alphas[lora_module_name] = alpha + elif "lora_down" in key: + lora_module_name = key[: key.rfind(".lora_down")] + dim = lora_sd[key].size()[0] + dims[lora_module_name] = dim + if lora_module_name not in base_dims: + base_dims[lora_module_name] = dim + + for lora_module_name in dims.keys(): + if lora_module_name not in alphas: + alpha = dims[lora_module_name] + alphas[lora_module_name] = alpha + if lora_module_name not in base_alphas: + base_alphas[lora_module_name] = alpha + + print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") + + # merge + print(f"merging...") + for key in tqdm(lora_sd.keys()): + if "alpha" in key: + continue + + if "lora_up" in key and concat: + concat_dim = 1 + elif "lora_down" in key and concat: + concat_dim = 0 + else: + concat_dim = None + + lora_module_name = key[: key.rfind(".lora_")] + + base_alpha = base_alphas[lora_module_name] + alpha = alphas[lora_module_name] + + scale = math.sqrt(alpha / base_alpha) * ratio + scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。 + + if key in 
merged_sd: + assert ( + merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None + ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません" + if concat_dim is not None: + merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim) + else: + merged_sd[key] = merged_sd[key] + lora_sd[key] * scale + else: + merged_sd[key] = lora_sd[key] * scale + + # set alpha to sd + for lora_module_name, alpha in base_alphas.items(): + key = lora_module_name + ".alpha" + merged_sd[key] = torch.tensor(alpha) + if shuffle: + key_down = lora_module_name + ".lora_down.weight" + key_up = lora_module_name + ".lora_up.weight" + dim = merged_sd[key_down].shape[0] + perm = torch.randperm(dim) + merged_sd[key_down] = merged_sd[key_down][perm] + merged_sd[key_up] = merged_sd[key_up][:,perm] + + print("merged model") + print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") + + # check all dims are same + dims_list = list(set(base_dims.values())) + alphas_list = list(set(base_alphas.values())) + all_same_dims = True + all_same_alphas = True + for dims in dims_list: + if dims != dims_list[0]: + all_same_dims = False + break + for alphas in alphas_list: + if alphas != alphas_list[0]: + all_same_alphas = False + break + + # build minimum metadata + dims = f"{dims_list[0]}" if all_same_dims else "Dynamic" + alphas = f"{alphas_list[0]}" if all_same_alphas else "Dynamic" + metadata = train_util.build_minimum_network_metadata(v2, base_model, "networks.lora", dims, alphas, None) + + return merged_sd, metadata + + +def merge(args): + assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" + + def str_to_dtype(p): + if p == "float": + return torch.float + if p == "fp16": + return torch.float16 + if p == "bf16": + return torch.bfloat16 + return None + + merge_dtype = str_to_dtype(args.precision) + save_dtype = str_to_dtype(args.save_precision) + if save_dtype is None: + save_dtype = merge_dtype + + if args.sd_model is not None: + print(f"loading SD model: {args.sd_model}") + + ( + text_model1, + text_model2, + vae, + unet, + logit_scale, + ckpt_info, + ) = sdxl_model_util.load_models_from_sdxl_checkpoint(sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.sd_model, "cpu") + + merge_to_sd_model(text_model1, text_model2, unet, args.models, args.ratios, merge_dtype) + + if args.no_metadata: + sai_metadata = None + else: + merged_from = sai_model_spec.build_merged_from([args.sd_model] + args.models) + title = os.path.splitext(os.path.basename(args.save_to))[0] + sai_metadata = sai_model_spec.build_metadata( + None, False, False, True, False, False, time.time(), title=title, merged_from=merged_from + ) + + print(f"saving SD model to: {args.save_to}") + sdxl_model_util.save_stable_diffusion_checkpoint( + args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, sai_metadata, save_dtype + ) + else: + state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) + + print(f"calculating hashes and creating metadata...") + + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + if not args.no_metadata: + merged_from = sai_model_spec.build_merged_from(args.models) + title = os.path.splitext(os.path.basename(args.save_to))[0] + sai_metadata = 
sai_model_spec.build_metadata( + state_dict, False, False, True, True, False, time.time(), title=title, merged_from=merged_from + ) + metadata.update(sai_metadata) + + print(f"saving model to: {args.save_to}") + save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata) + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument( + "--save_precision", + type=str, + default=None, + choices=[None, "float", "fp16", "bf16"], + help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ", + ) + parser.add_argument( + "--precision", + type=str, + default="float", + choices=["float", "fp16", "bf16"], + help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)", + ) + parser.add_argument( + "--sd_model", + type=str, + default=None, + help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする", + ) + parser.add_argument( + "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors" + ) + parser.add_argument( + "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors" + ) + parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") + parser.add_argument( + "--no_metadata", + action="store_true", + help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / " + + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)", + ) + parser.add_argument( + "--concat", + action="store_true", + help="concat lora instead of merge (The dim(rank) of the output LoRA is the sum of the input dims) / " + + "マージの代わりに結合する(LoRAのdim(rank)は入力dimの合計になる)", + ) + parser.add_argument( + "--shuffle", + action="store_true", + help="shuffle lora weight./ " + + "LoRAの重みをシャッフルする", + ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + merge(args) diff --git a/networks/svd_merge_lora.py b/networks/svd_merge_lora.py index 9d17efba5..16e813b36 100644 --- a/networks/svd_merge_lora.py +++ b/networks/svd_merge_lora.py @@ -1,10 +1,11 @@ - import math import argparse import os +import time import torch from safetensors.torch import load_file, save_file from tqdm import tqdm +from library import sai_model_spec, train_util import library.model_util as model_util import lora @@ -13,180 +14,247 @@ def load_state_dict(file_name, dtype): - if os.path.splitext(file_name)[1] == '.safetensors': - sd = load_file(file_name) - else: - sd = torch.load(file_name, map_location='cpu') - for key in list(sd.keys()): - if type(sd[key]) == torch.Tensor: - sd[key] = sd[key].to(dtype) - return sd - + if os.path.splitext(file_name)[1] == ".safetensors": + sd = load_file(file_name) + metadata = train_util.load_metadata_from_safetensors(file_name) + else: + sd = torch.load(file_name, map_location="cpu") + metadata = {} -def save_to_file(file_name, state_dict, dtype): - if dtype is not None: - for key in list(state_dict.keys()): - if type(state_dict[key]) == torch.Tensor: - state_dict[key] = state_dict[key].to(dtype) + for key in list(sd.keys()): + if type(sd[key]) == torch.Tensor: + sd[key] = sd[key].to(dtype) - if os.path.splitext(file_name)[1] == '.safetensors': - save_file(state_dict, file_name) - else: - torch.save(state_dict, file_name) + return sd, metadata -def merge_lora_models(models, ratios, new_rank, new_conv_rank, 
device, merge_dtype): - print(f"new rank: {new_rank}, new conv rank: {new_conv_rank}") - merged_sd = {} - for model, ratio in zip(models, ratios): - print(f"loading: {model}") - lora_sd = load_state_dict(model, merge_dtype) - - # merge - print(f"merging...") - for key in tqdm(list(lora_sd.keys())): - if 'lora_down' not in key: - continue - - lora_module_name = key[:key.rfind(".lora_down")] - - down_weight = lora_sd[key] - network_dim = down_weight.size()[0] - - up_weight = lora_sd[lora_module_name + '.lora_up.weight'] - alpha = lora_sd.get(lora_module_name + '.alpha', network_dim) - - in_dim = down_weight.size()[1] - out_dim = up_weight.size()[0] - conv2d = len(down_weight.size()) == 4 - kernel_size = None if not conv2d else down_weight.size()[2:4] - # print(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size) - - # make original weight if not exist - if lora_module_name not in merged_sd: - weight = torch.zeros((out_dim, in_dim, *kernel_size) if conv2d else (out_dim, in_dim), dtype=merge_dtype) - if device: - weight = weight.to(device) - else: - weight = merged_sd[lora_module_name] - - # merge to weight - if device: - up_weight = up_weight.to(device) - down_weight = down_weight.to(device) - - # W <- W + U * D - scale = (alpha / network_dim) - - if device: # and isinstance(scale, torch.Tensor): - scale = scale.to(device) - - if not conv2d: # linear - weight = weight + ratio * (up_weight @ down_weight) * scale - elif kernel_size == (1, 1): - weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2) - ).unsqueeze(2).unsqueeze(3) * scale - else: - conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - weight = weight + ratio * conved * scale - - merged_sd[lora_module_name] = weight - - # extract from merged weights - print("extract new lora...") - merged_lora_sd = {} - with torch.no_grad(): - for lora_module_name, mat in tqdm(list(merged_sd.items())): - conv2d = (len(mat.size()) == 4) - kernel_size = None if not conv2d else mat.size()[2:4] - conv2d_3x3 = conv2d and kernel_size != (1, 1) - out_dim, in_dim = mat.size()[0:2] - - if conv2d: - if conv2d_3x3: - mat = mat.flatten(start_dim=1) - else: - mat = mat.squeeze() - - module_new_rank = new_conv_rank if conv2d_3x3 else new_rank - module_new_rank = min(module_new_rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim - - U, S, Vh = torch.linalg.svd(mat) - - U = U[:, :module_new_rank] - S = S[:module_new_rank] - U = U @ torch.diag(S) - - Vh = Vh[:module_new_rank, :] - - dist = torch.cat([U.flatten(), Vh.flatten()]) - hi_val = torch.quantile(dist, CLAMP_QUANTILE) - low_val = -hi_val - - U = U.clamp(low_val, hi_val) - Vh = Vh.clamp(low_val, hi_val) - - if conv2d: - U = U.reshape(out_dim, module_new_rank, 1, 1) - Vh = Vh.reshape(module_new_rank, in_dim, kernel_size[0], kernel_size[1]) - - up_weight = U - down_weight = Vh - - merged_lora_sd[lora_module_name + '.lora_up.weight'] = up_weight.to("cpu").contiguous() - merged_lora_sd[lora_module_name + '.lora_down.weight'] = down_weight.to("cpu").contiguous() - merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(module_new_rank) - - return merged_lora_sd +def save_to_file(file_name, state_dict, dtype, metadata): + if dtype is not None: + for key in list(state_dict.keys()): + if type(state_dict[key]) == torch.Tensor: + state_dict[key] = state_dict[key].to(dtype) + if os.path.splitext(file_name)[1] == ".safetensors": + save_file(state_dict, file_name, metadata=metadata) + else: + 
torch.save(state_dict, file_name) -def merge(args): - assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" - def str_to_dtype(p): - if p == 'float': - return torch.float - if p == 'fp16': - return torch.float16 - if p == 'bf16': - return torch.bfloat16 - return None - - merge_dtype = str_to_dtype(args.precision) - save_dtype = str_to_dtype(args.save_precision) - if save_dtype is None: - save_dtype = merge_dtype +def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype): + print(f"new rank: {new_rank}, new conv rank: {new_conv_rank}") + merged_sd = {} + v2 = None + base_model = None + for model, ratio in zip(models, ratios): + print(f"loading: {model}") + lora_sd, lora_metadata = load_state_dict(model, merge_dtype) + + if lora_metadata is not None: + if v2 is None: + v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None) # return string + if base_model is None: + base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None) + + # merge + print(f"merging...") + for key in tqdm(list(lora_sd.keys())): + if "lora_down" not in key: + continue + + lora_module_name = key[: key.rfind(".lora_down")] + + down_weight = lora_sd[key] + network_dim = down_weight.size()[0] + + up_weight = lora_sd[lora_module_name + ".lora_up.weight"] + alpha = lora_sd.get(lora_module_name + ".alpha", network_dim) + + in_dim = down_weight.size()[1] + out_dim = up_weight.size()[0] + conv2d = len(down_weight.size()) == 4 + kernel_size = None if not conv2d else down_weight.size()[2:4] + # print(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size) + + # make original weight if not exist + if lora_module_name not in merged_sd: + weight = torch.zeros((out_dim, in_dim, *kernel_size) if conv2d else (out_dim, in_dim), dtype=merge_dtype) + if device: + weight = weight.to(device) + else: + weight = merged_sd[lora_module_name] + + # merge to weight + if device: + up_weight = up_weight.to(device) + down_weight = down_weight.to(device) + + # W <- W + U * D + scale = alpha / network_dim + + if device: # and isinstance(scale, torch.Tensor): + scale = scale.to(device) + + if not conv2d: # linear + weight = weight + ratio * (up_weight @ down_weight) * scale + elif kernel_size == (1, 1): + weight = ( + weight + + ratio + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * scale + ) + else: + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + weight = weight + ratio * conved * scale + + merged_sd[lora_module_name] = weight + + # extract from merged weights + print("extract new lora...") + merged_lora_sd = {} + with torch.no_grad(): + for lora_module_name, mat in tqdm(list(merged_sd.items())): + conv2d = len(mat.size()) == 4 + kernel_size = None if not conv2d else mat.size()[2:4] + conv2d_3x3 = conv2d and kernel_size != (1, 1) + out_dim, in_dim = mat.size()[0:2] + + if conv2d: + if conv2d_3x3: + mat = mat.flatten(start_dim=1) + else: + mat = mat.squeeze() + + module_new_rank = new_conv_rank if conv2d_3x3 else new_rank + module_new_rank = min(module_new_rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim + + U, S, Vh = torch.linalg.svd(mat) + + U = U[:, :module_new_rank] + S = S[:module_new_rank] + U = U @ torch.diag(S) + + Vh = Vh[:module_new_rank, :] + + dist = torch.cat([U.flatten(), Vh.flatten()]) + hi_val = torch.quantile(dist, CLAMP_QUANTILE) + low_val = -hi_val + + U = U.clamp(low_val, 
hi_val) + Vh = Vh.clamp(low_val, hi_val) + + if conv2d: + U = U.reshape(out_dim, module_new_rank, 1, 1) + Vh = Vh.reshape(module_new_rank, in_dim, kernel_size[0], kernel_size[1]) + + up_weight = U + down_weight = Vh + + merged_lora_sd[lora_module_name + ".lora_up.weight"] = up_weight.to("cpu").contiguous() + merged_lora_sd[lora_module_name + ".lora_down.weight"] = down_weight.to("cpu").contiguous() + merged_lora_sd[lora_module_name + ".alpha"] = torch.tensor(module_new_rank) + + # build minimum metadata + dims = f"{new_rank}" + alphas = f"{new_rank}" + if new_conv_rank is not None: + network_args = {"conv_dim": new_conv_rank, "conv_alpha": new_conv_rank} + else: + network_args = None + metadata = train_util.build_minimum_network_metadata(v2, base_model, "networks.lora", dims, alphas, network_args) + + return merged_lora_sd, metadata, v2 == "True", base_model - new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank - state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype) - print(f"saving model to: {args.save_to}") - save_to_file(args.save_to, state_dict, save_dtype) +def merge(args): + assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" + + def str_to_dtype(p): + if p == "float": + return torch.float + if p == "fp16": + return torch.float16 + if p == "bf16": + return torch.bfloat16 + return None + + merge_dtype = str_to_dtype(args.precision) + save_dtype = str_to_dtype(args.save_precision) + if save_dtype is None: + save_dtype = merge_dtype + + new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank + state_dict, metadata, v2, base_model = merge_lora_models( + args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype + ) + + print(f"calculating hashes and creating metadata...") + + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + if not args.no_metadata: + is_sdxl = base_model is not None and base_model.lower().startswith("sdxl") + merged_from = sai_model_spec.build_merged_from(args.models) + title = os.path.splitext(os.path.basename(args.save_to))[0] + sai_metadata = sai_model_spec.build_metadata( + state_dict, v2, v2, is_sdxl, True, False, time.time(), title=title, merged_from=merged_from + ) + if v2: + # TODO read sai modelspec + print( + "Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します" + ) + metadata.update(sai_metadata) + + print(f"saving model to: {args.save_to}") + save_to_file(args.save_to, state_dict, save_dtype, metadata) def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument("--save_precision", type=str, default=None, - choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ") - parser.add_argument("--precision", type=str, default="float", - choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)") - parser.add_argument("--save_to", type=str, default=None, - help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors") - parser.add_argument("--models", type=str, nargs='*', - help="LoRA models to merge: ckpt or safetensors file / 
マージするLoRAモデル、ckptまたはsafetensors") - parser.add_argument("--ratios", type=float, nargs='*', - help="ratios for each model / それぞれのLoRAモデルの比率") - parser.add_argument("--new_rank", type=int, default=4, - help="Specify rank of output LoRA / 出力するLoRAのrank (dim)") - parser.add_argument("--new_conv_rank", type=int, default=None, - help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ") - parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") - - return parser - - -if __name__ == '__main__': - parser = setup_parser() - - args = parser.parse_args() - merge(args) + parser = argparse.ArgumentParser() + parser.add_argument( + "--save_precision", + type=str, + default=None, + choices=[None, "float", "fp16", "bf16"], + help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ", + ) + parser.add_argument( + "--precision", + type=str, + default="float", + choices=["float", "fp16", "bf16"], + help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)", + ) + parser.add_argument( + "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors" + ) + parser.add_argument( + "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors" + ) + parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") + parser.add_argument("--new_rank", type=int, default=4, help="Specify rank of output LoRA / 出力するLoRAのrank (dim)") + parser.add_argument( + "--new_conv_rank", + type=int, + default=None, + help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ", + ) + parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") + parser.add_argument( + "--no_metadata", + action="store_true", + help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / " + + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)", + ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + merge(args) diff --git a/requirements.txt b/requirements.txt index debe2c789..4ca393f52 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,26 +1,27 @@ -accelerate==0.15.0 -transformers==4.26.0 +accelerate==0.23.0 +transformers==4.30.2 +diffusers[torch]==0.21.2 ftfy==6.1.1 -albumentations==1.3.0 +# albumentations==1.3.0 opencv-python==4.7.0.68 einops==0.6.0 -diffusers[torch]==0.10.2 pytorch-lightning==1.9.0 -bitsandbytes==0.35.0 +# bitsandbytes==0.39.1 tensorboard==2.10.1 -safetensors==0.2.6 +safetensors==0.3.1 # gradio==3.16.2 altair==4.2.2 easygui==0.98.3 toml==0.10.2 voluptuous==0.13.1 +huggingface-hub==0.15.1 # for BLIP captioning -requests==2.28.2 -timm==0.6.12 -fairscale==0.4.13 +# requests==2.28.2 +# timm==0.6.12 +# fairscale==0.4.13 # for WD14 captioning -# tensorflow<2.11 -tensorflow==2.10.1 -huggingface-hub==0.15.1 +# tensorflow==2.10.1 +# open clip for SDXL +open-clip-torch==2.20.0 # for kohya_ss library -. +-e . 
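For reference, below is a minimal, self-contained sketch (not part of the diff above) of the rank-reduction step that `networks/svd_merge_lora.py` applies per module: the merged LoRA delta is materialized as a full weight update, factored back into `lora_up`/`lora_down` at the requested rank via truncated SVD, and clamped at a quantile. The helper name `resize_linear_lora` and the `CLAMP_QUANTILE` value are assumptions for illustration only; the script defines its own constant outside the excerpt shown here.

```python
# Hypothetical illustration of the SVD rank-reduction used by svd_merge_lora.py
# for a single Linear-layer LoRA (Conv2d 3x3 weights are flattened to 2-D first).
import torch

CLAMP_QUANTILE = 0.99  # assumed value; the actual constant lives outside this excerpt


def resize_linear_lora(up_weight, down_weight, alpha, ratio, new_rank):
    # materialize the full delta: ratio * (U @ D) * (alpha / rank)
    rank = down_weight.size(0)
    delta = ratio * (up_weight @ down_weight) * (alpha / rank)

    # truncated SVD back to the requested rank
    U, S, Vh = torch.linalg.svd(delta)
    U = U[:, :new_rank] @ torch.diag(S[:new_rank])
    Vh = Vh[:new_rank, :]

    # clamp outliers, as the script does, to keep the factors well-behaved
    hi_val = torch.quantile(torch.cat([U.flatten(), Vh.flatten()]), CLAMP_QUANTILE)
    U = U.clamp(-hi_val, hi_val)
    Vh = Vh.clamp(-hi_val, hi_val)

    # new_rank is stored as both dim and alpha, so the effective scale stays 1.0
    return U.contiguous(), Vh.contiguous(), torch.tensor(float(new_rank))


# toy usage with random weights
up = torch.randn(320, 16)
down = torch.randn(16, 768)
new_up, new_down, new_alpha = resize_linear_lora(up, down, alpha=8.0, ratio=1.0, new_rank=4)
print(new_up.shape, new_down.shape, new_alpha)  # (320, 4), (4, 768), tensor(4.)
```

In the script itself the same logic runs over the merged state dict module by module, with `--new_conv_rank` applied to Conv2d 3x3 weights, and the result is saved together with the minimum network metadata and hashes added in the diff above.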
diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py new file mode 100755 index 000000000..ac01b76e0 --- /dev/null +++ b/sdxl_gen_img.py @@ -0,0 +1,2717 @@ +import itertools +import json +from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable +import glob +import importlib +import inspect +import time +import zipfile +from diffusers.utils import deprecate +from diffusers.configuration_utils import FrozenDict +import argparse +import math +import os +import random +import re + +import diffusers +import numpy as np +import torch +try: + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + from library.ipex import ipex_init + ipex_init() +except Exception: + pass +import torchvision +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + DPMSolverSinglestepScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + DDIMScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + KDPM2DiscreteScheduler, + KDPM2AncestralDiscreteScheduler, + # UNet2DConditionModel, + StableDiffusionPipeline, +) +from einops import rearrange +from tqdm import tqdm +from torchvision import transforms +from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPImageProcessor +import PIL +from PIL import Image +from PIL.PngImagePlugin import PngInfo + +import library.model_util as model_util +import library.train_util as train_util +import library.sdxl_model_util as sdxl_model_util +import library.sdxl_train_util as sdxl_train_util +from networks.lora import LoRANetwork +from library.sdxl_original_unet import SdxlUNet2DConditionModel +from library.original_unet import FlashAttentionFunction +from networks.control_net_lllite import ControlNetLLLite + +# scheduler: +SCHEDULER_LINEAR_START = 0.00085 +SCHEDULER_LINEAR_END = 0.0120 +SCHEDULER_TIMESTEPS = 1000 +SCHEDLER_SCHEDULE = "scaled_linear" + +# その他の設定 +LATENT_CHANNELS = 4 +DOWNSAMPLING_FACTOR = 8 + +CLIP_VISION_MODEL = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + +# region モジュール入れ替え部 +""" +高速化のためのモジュール入れ替え +""" + + +def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa): + if mem_eff_attn: + print("Enable memory efficient attention for U-Net") + + # これはDiffusersのU-Netではなく自前のU-Netなので置き換えなくても良い + unet.set_use_memory_efficient_attention(False, True) + elif xformers: + print("Enable xformers for U-Net") + try: + import xformers.ops + except ImportError: + raise ImportError("No xformers / xformersがインストールされていないようです") + + unet.set_use_memory_efficient_attention(True, False) + elif sdpa: + print("Enable SDPA for U-Net") + unet.set_use_memory_efficient_attention(False, False) + unet.set_use_sdpa(True) + + +# TODO common train_util.py +def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers, sdpa): + if mem_eff_attn: + replace_vae_attn_to_memory_efficient() + elif xformers: + # replace_vae_attn_to_xformers() # 解像度によってxformersがエラーを出す? 
+ vae.set_use_memory_efficient_attention_xformers(True) # とりあえずこっちを使う + elif sdpa: + replace_vae_attn_to_sdpa() + + +def replace_vae_attn_to_memory_efficient(): + print("VAE Attention.forward has been replaced to FlashAttention (not xformers)") + flash_func = FlashAttentionFunction + + def forward_flash_attn(self, hidden_states, **kwargs): + q_bucket_size = 512 + k_bucket_size = 1024 + + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = self.to_q(hidden_states) + key_proj = self.to_k(hidden_states) + value_proj = self.to_v(hidden_states) + + query_proj, key_proj, value_proj = map( + lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (query_proj, key_proj, value_proj) + ) + + out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size) + + out = rearrange(out, "b h n d -> b n (h d)") + + # compute next hidden_states + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + + def forward_flash_attn_0_14(self, hidden_states, **kwargs): + if not hasattr(self, "to_q"): + self.to_q = self.query + self.to_k = self.key + self.to_v = self.value + self.to_out = [self.proj_attn, torch.nn.Identity()] + self.heads = self.num_heads + return forward_flash_attn(self, hidden_states, **kwargs) + + if diffusers.__version__ < "0.15.0": + diffusers.models.attention.AttentionBlock.forward = forward_flash_attn_0_14 + else: + diffusers.models.attention_processor.Attention.forward = forward_flash_attn + + +def replace_vae_attn_to_xformers(): + print("VAE: Attention.forward has been replaced to xformers") + import xformers.ops + + def forward_xformers(self, hidden_states, **kwargs): + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = self.to_q(hidden_states) + key_proj = self.to_k(hidden_states) + value_proj = self.to_v(hidden_states) + + query_proj, key_proj, value_proj = map( + lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (query_proj, key_proj, value_proj) + ) + + query_proj = query_proj.contiguous() + key_proj = key_proj.contiguous() + value_proj = value_proj.contiguous() + out = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None) + + out = rearrange(out, "b h n d -> b n (h d)") + + # compute next hidden_states + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + + def forward_xformers_0_14(self, hidden_states, **kwargs): + if not hasattr(self, "to_q"): + self.to_q = self.query + self.to_k = self.key + self.to_v = self.value + self.to_out = [self.proj_attn, torch.nn.Identity()] + self.heads = self.num_heads + return 
forward_xformers(self, hidden_states, **kwargs) + + if diffusers.__version__ < "0.15.0": + diffusers.models.attention.AttentionBlock.forward = forward_xformers_0_14 + else: + diffusers.models.attention_processor.Attention.forward = forward_xformers + + +def replace_vae_attn_to_sdpa(): + print("VAE: Attention.forward has been replaced to sdpa") + + def forward_sdpa(self, hidden_states, **kwargs): + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = self.to_q(hidden_states) + key_proj = self.to_k(hidden_states) + value_proj = self.to_v(hidden_states) + + query_proj, key_proj, value_proj = map( + lambda t: rearrange(t, "b n (h d) -> b n h d", h=self.heads), (query_proj, key_proj, value_proj) + ) + + out = torch.nn.functional.scaled_dot_product_attention( + query_proj, key_proj, value_proj, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + out = rearrange(out, "b n h d -> b n (h d)") + + # compute next hidden_states + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + + def forward_sdpa_0_14(self, hidden_states, **kwargs): + if not hasattr(self, "to_q"): + self.to_q = self.query + self.to_k = self.key + self.to_v = self.value + self.to_out = [self.proj_attn, torch.nn.Identity()] + self.heads = self.num_heads + return forward_sdpa(self, hidden_states, **kwargs) + + if diffusers.__version__ < "0.15.0": + diffusers.models.attention.AttentionBlock.forward = forward_sdpa_0_14 + else: + diffusers.models.attention_processor.Attention.forward = forward_sdpa + + +# endregion + +# region 画像生成の本体:lpw_stable_diffusion.py (ASL)からコピーして修正 +# https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py +# Pipelineだけ独立して使えないのと機能追加するのとでコピーして修正 + + +class PipelineLike: + def __init__( + self, + device, + vae: AutoencoderKL, + text_encoders: List[CLIPTextModel], + tokenizers: List[CLIPTokenizer], + unet: SdxlUNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + clip_skip: int, + ): + super().__init__() + self.device = device + self.clip_skip = clip_skip + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + self.vae = vae + self.text_encoders = text_encoders + self.tokenizers = tokenizers + self.unet: SdxlUNet2DConditionModel = unet + self.scheduler = scheduler + self.safety_checker = None + + self.clip_vision_model: CLIPVisionModelWithProjection = None + self.clip_vision_processor: CLIPImageProcessor = None + self.clip_vision_strength = 0.0 + + # Textual Inversion + self.token_replacements_list = [] + for _ in range(len(self.text_encoders)): + self.token_replacements_list.append({}) + + # ControlNet # not supported yet + self.control_nets: List[ControlNetLLLite] = [] + self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない + + # Textual Inversion + def add_token_replacement(self, text_encoder_index, target_token_id, rep_token_ids): + self.token_replacements_list[text_encoder_index][target_token_id] = rep_token_ids + + def set_enable_control_net(self, en: bool): + self.control_net_enabled = en + + def get_token_replacer(self, tokenizer): + tokenizer_index = self.tokenizers.index(tokenizer) + token_replacements = self.token_replacements_list[tokenizer_index] + + def replace_tokens(tokens): + # print("replace_tokens", tokens, "=>", token_replacements) + if isinstance(tokens, torch.Tensor): + tokens = tokens.tolist() + + new_tokens = [] + for token in tokens: + if token in token_replacements: + replacement = token_replacements[token] + new_tokens.extend(replacement) + else: + new_tokens.append(token) + return new_tokens + + return replace_tokens + + def set_control_nets(self, ctrl_nets): + self.control_nets = ctrl_nets + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + init_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, + mask_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, + height: int = 1024, + width: int = 1024, + original_height: int = None, + original_width: int = None, + original_height_negative: int = None, + original_width_negative: int = None, + crop_top: int = 0, + crop_left: int = 0, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_scale: float = None, + strength: float = 0.8, + # num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = 
None, + latents: Optional[torch.FloatTensor] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + vae_batch_size: float = None, + return_latents: bool = False, + # return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + is_cancelled_callback: Optional[Callable[[], bool]] = None, + callback_steps: Optional[int] = 1, + img2img_noise=None, + clip_guide_images=None, + **kwargs, + ): + # TODO support secondary prompt + num_images_per_prompt = 1 # fixed because already prompt is repeated + + if isinstance(prompt, str): + batch_size = 1 + prompt = [prompt] + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + reginonal_network = " AND " in prompt[0] + + vae_batch_size = ( + batch_size + if vae_batch_size is None + else (int(vae_batch_size) if vae_batch_size >= 1 else max(1, int(batch_size * vae_batch_size))) + ) + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." + ) + + # get prompt text embeddings + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + if not do_classifier_free_guidance and negative_scale is not None: + print(f"negative_scale is ignored if guidance scalle <= 1.0") + negative_scale = None + + # get unconditional embeddings for classifier free guidance + if negative_prompt is None: + negative_prompt = [""] * batch_size + elif isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * batch_size + if batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + + tes_text_embs = [] + tes_uncond_embs = [] + tes_real_uncond_embs = [] + + for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders): + token_replacer = self.get_token_replacer(tokenizer) + + # use last text_pool, because it is from text encoder 2 + text_embeddings, text_pool, uncond_embeddings, uncond_pool, _ = get_weighted_text_embeddings( + tokenizer, + text_encoder, + prompt=prompt, + uncond_prompt=negative_prompt if do_classifier_free_guidance else None, + max_embeddings_multiples=max_embeddings_multiples, + clip_skip=self.clip_skip, + token_replacer=token_replacer, + device=self.device, + **kwargs, + ) + tes_text_embs.append(text_embeddings) + tes_uncond_embs.append(uncond_embeddings) + + if negative_scale is not None: + _, real_uncond_embeddings, _ = get_weighted_text_embeddings( + token_replacer, + prompt=prompt, # こちらのトークン長に合わせてuncondを作るので75トークン超で必須 + uncond_prompt=[""] * batch_size, + max_embeddings_multiples=max_embeddings_multiples, + clip_skip=self.clip_skip, + token_replacer=token_replacer, + device=self.device, + **kwargs, + ) + tes_real_uncond_embs.append(real_uncond_embeddings) + + # concat text encoder outputs + text_embeddings = tes_text_embs[0] + uncond_embeddings = tes_uncond_embs[0] + for i in range(1, len(tes_text_embs)): + text_embeddings = torch.cat([text_embeddings, tes_text_embs[i]], dim=2) # n,77,2048 + uncond_embeddings = torch.cat([uncond_embeddings, tes_uncond_embs[i]], dim=2) # n,77,2048 + + if do_classifier_free_guidance: + if negative_scale is None: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + else: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]) + + if self.control_nets: + # ControlNetのhintにguide imageを流用する + if isinstance(clip_guide_images, PIL.Image.Image): + clip_guide_images = [clip_guide_images] + if isinstance(clip_guide_images[0], PIL.Image.Image): + clip_guide_images = [preprocess_image(im) for im in clip_guide_images] + clip_guide_images = torch.cat(clip_guide_images) + if isinstance(clip_guide_images, list): + clip_guide_images = torch.stack(clip_guide_images) + + clip_guide_images = clip_guide_images.to(self.device, dtype=text_embeddings.dtype) + + # create size embs + if original_height is None: + original_height = height + if original_width is None: + original_width = width + if original_height_negative is None: + original_height_negative = original_height + if original_width_negative is None: + original_width_negative = original_width + if crop_top is None: + crop_top = 0 + if crop_left is None: + crop_left = 0 + emb1 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256) + uc_emb1 = sdxl_train_util.get_timestep_embedding( + torch.FloatTensor([original_height_negative, original_width_negative]).unsqueeze(0), 256 + ) + emb2 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256) + emb3 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([height, width]).unsqueeze(0), 256) + c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1) + uc_vector = torch.cat([uc_emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1) + + if reginonal_network: + # use last pool for conditioning + num_sub_prompts = len(text_pool) // batch_size + text_pool = text_pool[num_sub_prompts - 1 :: num_sub_prompts] # last subprompt + + if init_image is not None and 
self.clip_vision_model is not None: + print(f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}") + vision_input = self.clip_vision_processor(init_image, return_tensors="pt", device=self.device) + pixel_values = vision_input["pixel_values"].to(self.device, dtype=text_embeddings.dtype) + + clip_vision_embeddings = self.clip_vision_model(pixel_values=pixel_values, output_hidden_states=True, return_dict=True) + clip_vision_embeddings = clip_vision_embeddings.image_embeds + + if len(clip_vision_embeddings) == 1 and batch_size > 1: + clip_vision_embeddings = clip_vision_embeddings.repeat((batch_size, 1)) + + clip_vision_embeddings = clip_vision_embeddings * self.clip_vision_strength + assert clip_vision_embeddings.shape == text_pool.shape, f"{clip_vision_embeddings.shape} != {text_pool.shape}" + text_pool = clip_vision_embeddings # replace: same as ComfyUI (?) + + c_vector = torch.cat([text_pool, c_vector], dim=1) + uc_vector = torch.cat([uncond_pool, uc_vector], dim=1) + + vector_embeddings = torch.cat([uc_vector, c_vector]) + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps, self.device) + + latents_dtype = text_embeddings.dtype + init_latents_orig = None + mask = None + + if init_image is None: + # get the initial random noise unless the user supplied it + + # Unlike in other pipelines, latents need to be generated in the target device + # for 1-to-1 results reproducibility with the CompVis implementation. + # However this currently doesn't work in `mps`. + latents_shape = ( + batch_size * num_images_per_prompt, + self.unet.in_channels, + height // 8, + width // 8, + ) + + if latents is None: + if self.device.type == "mps": + # randn does not exist on mps + latents = torch.randn( + latents_shape, + generator=generator, + device="cpu", + dtype=latents_dtype, + ).to(self.device) + else: + latents = torch.randn( + latents_shape, + generator=generator, + device=self.device, + dtype=latents_dtype, + ) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + latents = latents.to(self.device) + + timesteps = self.scheduler.timesteps.to(self.device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + else: + # image to tensor + if isinstance(init_image, PIL.Image.Image): + init_image = [init_image] + if isinstance(init_image[0], PIL.Image.Image): + init_image = [preprocess_image(im) for im in init_image] + init_image = torch.cat(init_image) + if isinstance(init_image, list): + init_image = torch.stack(init_image) + + # mask image to tensor + if mask_image is not None: + if isinstance(mask_image, PIL.Image.Image): + mask_image = [mask_image] + if isinstance(mask_image[0], PIL.Image.Image): + mask_image = torch.cat([preprocess_mask(im) for im in mask_image]) # H*W, 0 for repaint + + # encode the init image into latents and scale the latents + init_image = init_image.to(device=self.device, dtype=latents_dtype) + if init_image.size()[-2:] == (height // 8, width // 8): + init_latents = init_image + else: + if vae_batch_size >= batch_size: + init_latent_dist = self.vae.encode(init_image.to(self.vae.dtype)).latent_dist + init_latents = init_latent_dist.sample(generator=generator) + else: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + init_latents = [] + for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)): + init_latent_dist = self.vae.encode( + 
(init_image[i : i + vae_batch_size] if vae_batch_size > 1 else init_image[i].unsqueeze(0)).to( + self.vae.dtype + ) + ).latent_dist + init_latents.append(init_latent_dist.sample(generator=generator)) + init_latents = torch.cat(init_latents) + + init_latents = sdxl_model_util.VAE_SCALE_FACTOR * init_latents + + if len(init_latents) == 1: + init_latents = init_latents.repeat((batch_size, 1, 1, 1)) + init_latents_orig = init_latents + + # preprocess mask + if mask_image is not None: + mask = mask_image.to(device=self.device, dtype=latents_dtype) + if len(mask) == 1: + mask = mask.repeat((batch_size, 1, 1, 1)) + + # check sizes + if not mask.shape == init_latents.shape: + raise ValueError("The mask and init_image should be the same size!") + + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + timesteps = self.scheduler.timesteps[-init_timestep] + timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device) + + # add noise to latents using the timesteps + latents = self.scheduler.add_noise(init_latents, img2img_noise, timesteps) + + t_start = max(num_inference_steps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:].to(self.device) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1 + + if self.control_nets: + # guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images) + if self.control_net_enabled: + for control_net, _ in self.control_nets: + with torch.no_grad(): + control_net.set_cond_image(clip_guide_images) + else: + for control_net, _ in self.control_nets: + control_net.set_cond_image(None) + + each_control_net_enabled = [self.control_net_enabled] * len(self.control_nets) + for i, t in enumerate(tqdm(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # disable control net if ratio is set + if self.control_nets and self.control_net_enabled: + for j, ((control_net, ratio), enabled) in enumerate(zip(self.control_nets, each_control_net_enabled)): + if not enabled or ratio >= 1.0: + continue + if ratio < i / len(timesteps): + print(f"ControlNet {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})") + control_net.set_cond_image(None) + each_control_net_enabled[j] = False + + # predict the noise residual + # TODO Diffusers' ControlNet + # if self.control_nets and self.control_net_enabled: + # if reginonal_network: + # num_sub_and_neg_prompts = len(text_embeddings) // batch_size + # text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt + # else: + # text_emb_last = text_embeddings + + # # not working yet + # noise_pred = original_control_net.call_unet_and_control_net( + # 
i, + # num_latent_input, + # self.unet, + # self.control_nets, + # guided_hints, + # i / len(timesteps), + # latent_model_input, + # t, + # text_emb_last, + # ).sample + # else: + noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings) + + # perform guidance + if do_classifier_free_guidance: + if negative_scale is None: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input) # uncond by negative prompt + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + else: + noise_pred_negative, noise_pred_text, noise_pred_uncond = noise_pred.chunk( + num_latent_input + ) # uncond is real uncond + noise_pred = ( + noise_pred_uncond + + guidance_scale * (noise_pred_text - noise_pred_uncond) + - negative_scale * (noise_pred_negative - noise_pred_uncond) + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if mask is not None: + # masking + init_latents_proper = self.scheduler.add_noise(init_latents_orig, img2img_noise, torch.tensor([t])) + latents = (init_latents_proper * mask) + (latents * (1 - mask)) + + # call the callback, if provided + if i % callback_steps == 0: + if callback is not None: + callback(i, t, latents) + if is_cancelled_callback is not None and is_cancelled_callback(): + return None + + if return_latents: + return latents + + latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents + if vae_batch_size >= batch_size: + image = self.vae.decode(latents.to(self.vae.dtype)).sample + else: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + images = [] + for i in tqdm(range(0, batch_size, vae_batch_size)): + images.append( + self.vae.decode( + (latents[i : i + vae_batch_size] if vae_batch_size > 1 else latents[i].unsqueeze(0)).to(self.vae.dtype) + ).sample + ) + image = torch.cat(images) + + image = (image / 2 + 0.5).clamp(0, 1) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + if output_type == "pil": + # image = self.numpy_to_pil(image) + image = (image * 255).round().astype("uint8") + image = [Image.fromarray(im) for im in image] + + return image + + # return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + +re_attention = re.compile( + r""" +\\\(| +\\\)| +\\\[| +\\]| +\\\\| +\\| +\(| +\[| +:([+-]?[.\d]+)\)| +\)| +]| +[^\\()\[\]:]+| +: +""", + re.X, +) + + +def parse_prompt_attention(text): + """ + Parses a string with attention tokens and returns a list of pairs: text and its associated weight. 
+ Accepted tokens are: + (abc) - increases attention to abc by a multiplier of 1.1 + (abc:3.12) - increases attention to abc by a multiplier of 3.12 + [abc] - decreases attention to abc by a multiplier of 1.1 + \( - literal character '(' + \[ - literal character '[' + \) - literal character ')' + \] - literal character ']' + \\ - literal character '\' + anything else - just text + >>> parse_prompt_attention('normal text') + [['normal text', 1.0]] + >>> parse_prompt_attention('an (important) word') + [['an ', 1.0], ['important', 1.1], [' word', 1.0]] + >>> parse_prompt_attention('(unbalanced') + [['unbalanced', 1.1]] + >>> parse_prompt_attention('\(literal\]') + [['(literal]', 1.0]] + >>> parse_prompt_attention('(unnecessary)(parens)') + [['unnecessaryparens', 1.1]] + >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') + [['a ', 1.0], + ['house', 1.5730000000000004], + [' ', 1.1], + ['on', 1.0], + [' a ', 1.1], + ['hill', 0.55], + [', sun, ', 1.1], + ['sky', 1.4641000000000006], + ['.', 1.1]] + """ + + res = [] + round_brackets = [] + square_brackets = [] + + round_bracket_multiplier = 1.1 + square_bracket_multiplier = 1 / 1.1 + + def multiply_range(start_position, multiplier): + for p in range(start_position, len(res)): + res[p][1] *= multiplier + + # keep break as separate token + text = text.replace("BREAK", "\\BREAK\\") + + for m in re_attention.finditer(text): + text = m.group(0) + weight = m.group(1) + + if text.startswith("\\"): + res.append([text[1:], 1.0]) + elif text == "(": + round_brackets.append(len(res)) + elif text == "[": + square_brackets.append(len(res)) + elif weight is not None and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), float(weight)) + elif text == ")" and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), round_bracket_multiplier) + elif text == "]" and len(square_brackets) > 0: + multiply_range(square_brackets.pop(), square_bracket_multiplier) + else: + res.append([text, 1.0]) + + for pos in round_brackets: + multiply_range(pos, round_bracket_multiplier) + + for pos in square_brackets: + multiply_range(pos, square_bracket_multiplier) + + if len(res) == 0: + res = [["", 1.0]] + + # merge runs of identical weights + i = 0 + while i + 1 < len(res): + if res[i][1] == res[i + 1][1] and res[i][0].strip() != "BREAK" and res[i + 1][0].strip() != "BREAK": + res[i][0] += res[i + 1][0] + res.pop(i + 1) + else: + i += 1 + + return res + + +def get_prompts_with_weights(tokenizer: CLIPTokenizer, token_replacer, prompt: List[str], max_length: int): + r""" + Tokenize a list of prompts and return its tokens with weights of each token. + No padding, starting or ending token is included. 
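    For example, "an (important) word" yields the ids of ["an", "important", "word"] with per-token
    weights [1.0, 1.1, 1.0]; a word that splits into several BPE tokens copies its weight to every
    sub-token (illustrative tokenization, see parse_prompt_attention above).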
+ """ + tokens = [] + weights = [] + truncated = False + + for text in prompt: + texts_and_weights = parse_prompt_attention(text) + text_token = [] + text_weight = [] + for word, weight in texts_and_weights: + if word.strip() == "BREAK": + # pad until next multiple of tokenizer's max token length + pad_len = tokenizer.model_max_length - (len(text_token) % tokenizer.model_max_length) + print(f"BREAK pad_len: {pad_len}") + for i in range(pad_len): + # v2のときEOSをつけるべきかどうかわからないぜ + # if i == 0: + # text_token.append(tokenizer.eos_token_id) + # else: + text_token.append(tokenizer.pad_token_id) + text_weight.append(1.0) + continue + + # tokenize and discard the starting and the ending token + token = tokenizer(word).input_ids[1:-1] + + token = token_replacer(token) # for Textual Inversion + + text_token += token + # copy the weight by length of token + text_weight += [weight] * len(token) + # stop if the text is too long (longer than truncation limit) + if len(text_token) > max_length: + truncated = True + break + # truncate + if len(text_token) > max_length: + truncated = True + text_token = text_token[:max_length] + text_weight = text_weight[:max_length] + tokens.append(text_token) + weights.append(text_weight) + if truncated: + print("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") + return tokens, weights + + +def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77): + r""" + Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. + """ + max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) + weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length + for i in range(len(tokens)): + tokens[i] = [bos] + tokens[i] + [eos] + [pad] * (max_length - 2 - len(tokens[i])) + if no_boseos_middle: + weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) + else: + w = [] + if len(weights[i]) == 0: + w = [1.0] * weights_length + else: + for j in range(max_embeddings_multiples): + w.append(1.0) # weight for starting token in this chunk + w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))] + w.append(1.0) # weight for ending token in this chunk + w += [1.0] * (weights_length - len(w)) + weights[i] = w[:] + + return tokens, weights + + +def get_unweighted_text_embeddings( + text_encoder: CLIPTextModel, + text_input: torch.Tensor, + chunk_length: int, + clip_skip: int, + eos: int, + pad: int, + no_boseos_middle: Optional[bool] = True, +): + """ + When the length of tokens is a multiple of the capacity of the text encoder, + it should be split into chunks and sent to the text encoder individually. 
+ """ + max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) + if max_embeddings_multiples > 1: + text_embeddings = [] + pool = None + for i in range(max_embeddings_multiples): + # extract the i-th chunk + text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone() + + # cover the head and the tail by the starting and the ending tokens + text_input_chunk[:, 0] = text_input[0, 0] + if pad == eos: # v1 + text_input_chunk[:, -1] = text_input[0, -1] + else: # v2 + for j in range(len(text_input_chunk)): + if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある + text_input_chunk[j, -1] = eos + if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD + text_input_chunk[j, 1] = eos + + # -2 is same for Text Encoder 1 and 2 + enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True) + text_embedding = enc_out["hidden_states"][-2] + if pool is None: + pool = enc_out.get("text_embeds", None) # use 1st chunk, if provided + if pool is not None: + pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], text_input_chunk, eos) + + if no_boseos_middle: + if i == 0: + # discard the ending token + text_embedding = text_embedding[:, :-1] + elif i == max_embeddings_multiples - 1: + # discard the starting token + text_embedding = text_embedding[:, 1:] + else: + # discard both starting and ending tokens + text_embedding = text_embedding[:, 1:-1] + + text_embeddings.append(text_embedding) + text_embeddings = torch.concat(text_embeddings, axis=1) + else: + enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True) + text_embeddings = enc_out["hidden_states"][-2] + pool = enc_out.get("text_embeds", None) # text encoder 1 doesn't return this + if pool is not None: + pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], text_input, eos) + return text_embeddings, pool + + +def get_weighted_text_embeddings( + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModel, + prompt: Union[str, List[str]], + uncond_prompt: Optional[Union[str, List[str]]] = None, + max_embeddings_multiples: Optional[int] = 1, + no_boseos_middle: Optional[bool] = False, + skip_parsing: Optional[bool] = False, + skip_weighting: Optional[bool] = False, + clip_skip=None, + token_replacer=None, + device=None, + **kwargs, +): + max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 + if isinstance(prompt, str): + prompt = [prompt] + + # split the prompts with "AND". 
each prompt must have the same number of splits + new_prompts = [] + for p in prompt: + new_prompts.extend(p.split(" AND ")) + prompt = new_prompts + + if not skip_parsing: + prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, token_replacer, prompt, max_length - 2) + if uncond_prompt is not None: + if isinstance(uncond_prompt, str): + uncond_prompt = [uncond_prompt] + uncond_tokens, uncond_weights = get_prompts_with_weights(tokenizer, token_replacer, uncond_prompt, max_length - 2) + else: + prompt_tokens = [token[1:-1] for token in tokenizer(prompt, max_length=max_length, truncation=True).input_ids] + prompt_weights = [[1.0] * len(token) for token in prompt_tokens] + if uncond_prompt is not None: + if isinstance(uncond_prompt, str): + uncond_prompt = [uncond_prompt] + uncond_tokens = [token[1:-1] for token in tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids] + uncond_weights = [[1.0] * len(token) for token in uncond_tokens] + + # round up the longest length of tokens to a multiple of (model_max_length - 2) + max_length = max([len(token) for token in prompt_tokens]) + if uncond_prompt is not None: + max_length = max(max_length, max([len(token) for token in uncond_tokens])) + + max_embeddings_multiples = min( + max_embeddings_multiples, + (max_length - 1) // (tokenizer.model_max_length - 2) + 1, + ) + max_embeddings_multiples = max(1, max_embeddings_multiples) + max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 + + # pad the length of tokens and weights + bos = tokenizer.bos_token_id + eos = tokenizer.eos_token_id + pad = tokenizer.pad_token_id + prompt_tokens, prompt_weights = pad_tokens_and_weights( + prompt_tokens, + prompt_weights, + max_length, + bos, + eos, + pad, + no_boseos_middle=no_boseos_middle, + chunk_length=tokenizer.model_max_length, + ) + prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device) + if uncond_prompt is not None: + uncond_tokens, uncond_weights = pad_tokens_and_weights( + uncond_tokens, + uncond_weights, + max_length, + bos, + eos, + pad, + no_boseos_middle=no_boseos_middle, + chunk_length=tokenizer.model_max_length, + ) + uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=device) + + # get the embeddings + text_embeddings, text_pool = get_unweighted_text_embeddings( + text_encoder, + prompt_tokens, + tokenizer.model_max_length, + clip_skip, + eos, + pad, + no_boseos_middle=no_boseos_middle, + ) + prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device) + if uncond_prompt is not None: + uncond_embeddings, uncond_pool = get_unweighted_text_embeddings( + text_encoder, + uncond_tokens, + tokenizer.model_max_length, + clip_skip, + eos, + pad, + no_boseos_middle=no_boseos_middle, + ) + uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=device) + + # assign weights to the prompts and normalize in the sense of mean + # TODO: should we normalize by chunk or in a whole (current implementation)? 
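    # Toy-tensor sketch of the mean-preserving rescale applied just below (shapes are illustrative):
    #   emb = torch.randn(1, 77, 768)           # unweighted embeddings from one text encoder
    #   w = torch.ones(1, 77); w[0, 5] = 1.3    # "(word:1.3)" style emphasis on a single token
    #   prev = emb.float().mean(axis=[-2, -1])
    #   emb = emb * w.unsqueeze(-1)             # apply per-token weights
    #   emb = emb * (prev / emb.float().mean(axis=[-2, -1])).unsqueeze(-1).unsqueeze(-1)
    #   # the overall mean is restored, so emphasis only redistributes magnitude between tokens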
+ # →全体でいいんじゃないかな + if (not skip_parsing) and (not skip_weighting): + previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= prompt_weights.unsqueeze(-1) + current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + if uncond_prompt is not None: + previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= uncond_weights.unsqueeze(-1) + current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + + if uncond_prompt is not None: + return text_embeddings, text_pool, uncond_embeddings, uncond_pool, prompt_tokens + return text_embeddings, text_pool, None, None, prompt_tokens + + +def preprocess_image(image): + w, h = image.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 + + +def preprocess_mask(mask): + mask = mask.convert("L") + w, h = mask.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + mask = mask.resize((w // 8, h // 8), resample=PIL.Image.BILINEAR) # LANCZOS) + mask = np.array(mask).astype(np.float32) / 255.0 + mask = np.tile(mask, (4, 1, 1)) + mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? + mask = 1 - mask # repaint white, keep black + mask = torch.from_numpy(mask) + return mask + + +# regular expression for dynamic prompt: +# starts and ends with "{" and "}" +# contains at least one variant divided by "|" +# optional framgments divided by "$$" at start +# if the first fragment is "E" or "e", enumerate all variants +# if the second fragment is a number or two numbers, repeat the variants in the range +# if the third fragment is a string, use it as a separator + +RE_DYNAMIC_PROMPT = re.compile(r"\{((e|E)\$\$)?(([\d\-]+)\$\$)?(([^\|\}]+?)\$\$)?(.+?((\|).+?)*?)\}") + + +def handle_dynamic_prompt_variants(prompt, repeat_count): + founds = list(RE_DYNAMIC_PROMPT.finditer(prompt)) + if not founds: + return [prompt] + + # make each replacement for each variant + enumerating = False + replacers = [] + for found in founds: + # if "e$$" is found, enumerate all variants + found_enumerating = found.group(2) is not None + enumerating = enumerating or found_enumerating + + separator = ", " if found.group(6) is None else found.group(6) + variants = found.group(7).split("|") + + # parse count range + count_range = found.group(4) + if count_range is None: + count_range = [1, 1] + else: + count_range = count_range.split("-") + if len(count_range) == 1: + count_range = [int(count_range[0]), int(count_range[0])] + elif len(count_range) == 2: + count_range = [int(count_range[0]), int(count_range[1])] + else: + print(f"invalid count range: {count_range}") + count_range = [1, 1] + if count_range[0] > count_range[1]: + count_range = [count_range[1], count_range[0]] + if count_range[0] < 0: + count_range[0] = 0 + if count_range[1] > len(variants): + count_range[1] = len(variants) + + if found_enumerating: + # make function to enumerate all combinations + def make_replacer_enum(vari, cr, sep): + def replacer(): + values = [] + for count in range(cr[0], cr[1] + 1): + for comb in 
itertools.combinations(vari, count): + values.append(sep.join(comb)) + return values + + return replacer + + replacers.append(make_replacer_enum(variants, count_range, separator)) + else: + # make function to choose random combinations + def make_replacer_single(vari, cr, sep): + def replacer(): + count = random.randint(cr[0], cr[1]) + comb = random.sample(vari, count) + return [sep.join(comb)] + + return replacer + + replacers.append(make_replacer_single(variants, count_range, separator)) + + # make each prompt + if not enumerating: + # if not enumerating, repeat the prompt, replace each variant randomly + prompts = [] + for _ in range(repeat_count): + current = prompt + for found, replacer in zip(founds, replacers): + current = current.replace(found.group(0), replacer()[0], 1) + prompts.append(current) + else: + # if enumerating, iterate all combinations for previous prompts + prompts = [prompt] + + for found, replacer in zip(founds, replacers): + if found.group(2) is not None: + # make all combinations for existing prompts + new_prompts = [] + for current in prompts: + replecements = replacer() + for replecement in replecements: + new_prompts.append(current.replace(found.group(0), replecement, 1)) + prompts = new_prompts + + for found, replacer in zip(founds, replacers): + # make random selection for existing prompts + if found.group(2) is None: + for i in range(len(prompts)): + prompts[i] = prompts[i].replace(found.group(0), replacer()[0], 1) + + return prompts + + +# endregion + + +# def load_clip_l14_336(dtype): +# print(f"loading CLIP: {CLIP_ID_L14_336}") +# text_encoder = CLIPTextModel.from_pretrained(CLIP_ID_L14_336, torch_dtype=dtype) +# return text_encoder + + +class BatchDataBase(NamedTuple): + # バッチ分割が必要ないデータ + step: int + prompt: str + negative_prompt: str + seed: int + init_image: Any + mask_image: Any + clip_prompt: str + guide_image: Any + + +class BatchDataExt(NamedTuple): + # バッチ分割が必要なデータ + width: int + height: int + original_width: int + original_height: int + original_width_negative: int + original_height_negative: int + crop_left: int + crop_top: int + steps: int + scale: float + negative_scale: float + strength: float + network_muls: Tuple[float] + num_sub_prompts: int + + +class BatchData(NamedTuple): + return_latents: bool + base: BatchDataBase + ext: BatchDataExt + + +def main(args): + if args.fp16: + dtype = torch.float16 + elif args.bf16: + dtype = torch.bfloat16 + else: + dtype = torch.float32 + + highres_fix = args.highres_fix_scale is not None + # assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません" + + # モデルを読み込む + if not os.path.isfile(args.ckpt): # ファイルがないならパターンで探し、一つだけ該当すればそれを使う + files = glob.glob(args.ckpt) + if len(files) == 1: + args.ckpt = files[0] + + (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model( + args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype + ) + + # xformers、Hypernetwork対応 + if not args.diffusers_xformers: + mem_eff = not (args.xformers or args.sdpa) + replace_unet_modules(unet, mem_eff, args.xformers, args.sdpa) + replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa) + + # tokenizerを読み込む + print("loading tokenizer") + tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) + + # schedulerを用意する + sched_init_args = {} + has_steps_offset = True + has_clip_sample = True + scheduler_num_noises_per_step = 1 + + if args.sampler == "ddim": + scheduler_cls = DDIMScheduler + scheduler_module = 
diffusers.schedulers.scheduling_ddim + elif args.sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある + scheduler_cls = DDPMScheduler + scheduler_module = diffusers.schedulers.scheduling_ddpm + elif args.sampler == "pndm": + scheduler_cls = PNDMScheduler + scheduler_module = diffusers.schedulers.scheduling_pndm + has_clip_sample = False + elif args.sampler == "lms" or args.sampler == "k_lms": + scheduler_cls = LMSDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_lms_discrete + has_clip_sample = False + elif args.sampler == "euler" or args.sampler == "k_euler": + scheduler_cls = EulerDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_euler_discrete + has_clip_sample = False + elif args.sampler == "euler_a" or args.sampler == "k_euler_a": + scheduler_cls = EulerAncestralDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete + has_clip_sample = False + elif args.sampler == "dpmsolver" or args.sampler == "dpmsolver++": + scheduler_cls = DPMSolverMultistepScheduler + sched_init_args["algorithm_type"] = args.sampler + scheduler_module = diffusers.schedulers.scheduling_dpmsolver_multistep + has_clip_sample = False + elif args.sampler == "dpmsingle": + scheduler_cls = DPMSolverSinglestepScheduler + scheduler_module = diffusers.schedulers.scheduling_dpmsolver_singlestep + has_clip_sample = False + has_steps_offset = False + elif args.sampler == "heun": + scheduler_cls = HeunDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_heun_discrete + has_clip_sample = False + elif args.sampler == "dpm_2" or args.sampler == "k_dpm_2": + scheduler_cls = KDPM2DiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_discrete + has_clip_sample = False + elif args.sampler == "dpm_2_a" or args.sampler == "k_dpm_2_a": + scheduler_cls = KDPM2AncestralDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete + scheduler_num_noises_per_step = 2 + has_clip_sample = False + + # 警告を出さないようにする + if has_steps_offset: + sched_init_args["steps_offset"] = 1 + if has_clip_sample: + sched_init_args["clip_sample"] = False + + # samplerの乱数をあらかじめ指定するための処理 + + # replace randn + class NoiseManager: + def __init__(self): + self.sampler_noises = None + self.sampler_noise_index = 0 + + def reset_sampler_noises(self, noises): + self.sampler_noise_index = 0 + self.sampler_noises = noises + + def randn(self, shape, device=None, dtype=None, layout=None, generator=None): + # print("replacing", shape, len(self.sampler_noises), self.sampler_noise_index) + if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises): + noise = self.sampler_noises[self.sampler_noise_index] + if shape != noise.shape: + noise = None + else: + noise = None + + if noise == None: + print(f"unexpected noise request: {self.sampler_noise_index}, {shape}") + noise = torch.randn(shape, dtype=dtype, device=device, generator=generator) + + self.sampler_noise_index += 1 + return noise + + class TorchRandReplacer: + def __init__(self, noise_manager): + self.noise_manager = noise_manager + + def __getattr__(self, item): + if item == "randn": + return self.noise_manager.randn + if hasattr(torch, item): + return getattr(torch, item) + raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) + + noise_manager = NoiseManager() + if scheduler_module is not None: + scheduler_module.torch = TorchRandReplacer(noise_manager) + + scheduler = scheduler_cls( + 
num_train_timesteps=SCHEDULER_TIMESTEPS, + beta_start=SCHEDULER_LINEAR_START, + beta_end=SCHEDULER_LINEAR_END, + beta_schedule=SCHEDLER_SCHEDULE, + **sched_init_args, + ) + + # ↓以下は結局PipeでFalseに設定されるので意味がなかった + # # clip_sample=Trueにする + # if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: + # print("set clip_sample to True") + # scheduler.config.clip_sample = True + + # deviceを決定する + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # "mps"を考量してない + + # custom pipelineをコピったやつを生成する + if args.vae_slices: + from library.slicing_vae import SlicingAutoencoderKL + + sli_vae = SlicingAutoencoderKL( + act_fn="silu", + block_out_channels=(128, 256, 512, 512), + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"], + in_channels=3, + latent_channels=4, + layers_per_block=2, + norm_num_groups=32, + out_channels=3, + sample_size=512, + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"], + num_slices=args.vae_slices, + ) + sli_vae.load_state_dict(vae.state_dict()) # vaeのパラメータをコピーする + vae = sli_vae + del sli_vae + + vae_dtype = dtype + if args.no_half_vae: + print("set vae_dtype to float32") + vae_dtype = torch.float32 + vae.to(vae_dtype).to(device) + + text_encoder1.to(dtype).to(device) + text_encoder2.to(dtype).to(device) + unet.to(dtype).to(device) + + # networkを組み込む + if args.network_module: + networks = [] + network_default_muls = [] + network_pre_calc = args.network_pre_calc + + for i, network_module in enumerate(args.network_module): + print("import network module:", network_module) + imported_module = importlib.import_module(network_module) + + network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] + network_default_muls.append(network_mul) + + net_kwargs = {} + if args.network_args and i < len(args.network_args): + network_args = args.network_args[i] + # TODO escape special chars + network_args = network_args.split(";") + for net_arg in network_args: + key, value = net_arg.split("=") + net_kwargs[key] = value + + if args.network_weights and i < len(args.network_weights): + network_weight = args.network_weights[i] + print("load network weights from:", network_weight) + + if model_util.is_safetensors(network_weight) and args.network_show_meta: + from safetensors.torch import safe_open + + with safe_open(network_weight, framework="pt") as f: + metadata = f.metadata() + if metadata is not None: + print(f"metadata for: {network_weight}: {metadata}") + + network, weights_sd = imported_module.create_network_from_weights( + network_mul, network_weight, vae, [text_encoder1, text_encoder2], unet, for_inference=True, **net_kwargs + ) + else: + raise ValueError("No weight. Weight is required.") + if network is None: + return + + mergeable = network.is_mergeable() + if args.network_merge and not mergeable: + print("network is not mergiable. 
ignore merge option.") + + if not args.network_merge or not mergeable: + network.apply_to([text_encoder1, text_encoder2], unet) + info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい + print(f"weights are loaded: {info}") + + if args.opt_channels_last: + network.to(memory_format=torch.channels_last) + network.to(dtype).to(device) + + if network_pre_calc: + print("backup original weights") + network.backup_weights() + + networks.append(network) + else: + network.merge_to([text_encoder1, text_encoder2], unet, weights_sd, dtype, device) + + else: + networks = [] + + # upscalerの指定があれば取得する + upscaler = None + if args.highres_fix_upscaler: + print("import upscaler module:", args.highres_fix_upscaler) + imported_module = importlib.import_module(args.highres_fix_upscaler) + + us_kwargs = {} + if args.highres_fix_upscaler_args: + for net_arg in args.highres_fix_upscaler_args.split(";"): + key, value = net_arg.split("=") + us_kwargs[key] = value + + print("create upscaler") + upscaler = imported_module.create_upscaler(**us_kwargs) + upscaler.to(dtype).to(device) + + # ControlNetの処理 + control_nets: List[Tuple[ControlNetLLLite, float]] = [] + # if args.control_net_models: + # for i, model in enumerate(args.control_net_models): + # prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i] + # weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i] + # ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] + + # ctrl_unet, ctrl_net = original_control_net.load_control_net(False, unet, model) + # prep = original_control_net.load_preprocess(prep_type) + # control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) + if args.control_net_lllite_models: + for i, model_file in enumerate(args.control_net_lllite_models): + print(f"loading ControlNet-LLLite: {model_file}") + + from safetensors.torch import load_file + + state_dict = load_file(model_file) + mlp_dim = None + cond_emb_dim = None + for key, value in state_dict.items(): + if mlp_dim is None and "down.0.weight" in key: + mlp_dim = value.shape[0] + elif cond_emb_dim is None and "conditioning1.0" in key: + cond_emb_dim = value.shape[0] * 2 + if mlp_dim is not None and cond_emb_dim is not None: + break + assert mlp_dim is not None and cond_emb_dim is not None, f"invalid control net: {model_file}" + + multiplier = ( + 1.0 + if not args.control_net_multipliers or len(args.control_net_multipliers) <= i + else args.control_net_multipliers[i] + ) + ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] + + control_net = ControlNetLLLite(unet, cond_emb_dim, mlp_dim, multiplier=multiplier) + control_net.apply_to() + control_net.load_state_dict(state_dict) + control_net.to(dtype).to(device) + control_net.set_batch_cond_only(False, False) + control_nets.append((control_net, ratio)) + + if args.opt_channels_last: + print(f"set optimizing: channels last") + text_encoder1.to(memory_format=torch.channels_last) + text_encoder2.to(memory_format=torch.channels_last) + vae.to(memory_format=torch.channels_last) + unet.to(memory_format=torch.channels_last) + if networks: + for network in networks: + network.to(memory_format=torch.channels_last) + + for cn in control_nets: + cn.to(memory_format=torch.channels_last) + # cn.unet.to(memory_format=torch.channels_last) + # 
cn.net.to(memory_format=torch.channels_last) + + pipe = PipelineLike( + device, + vae, + [text_encoder1, text_encoder2], + [tokenizer1, tokenizer2], + unet, + scheduler, + args.clip_skip, + ) + pipe.set_control_nets(control_nets) + print("pipeline is ready.") + + if args.diffusers_xformers: + pipe.enable_xformers_memory_efficient_attention() + + # Textual Inversionを処理する + if args.textual_inversion_embeddings: + token_ids_embeds1 = [] + token_ids_embeds2 = [] + for embeds_file in args.textual_inversion_embeddings: + if model_util.is_safetensors(embeds_file): + from safetensors.torch import load_file + + data = load_file(embeds_file) + else: + data = torch.load(embeds_file, map_location="cpu") + + if "string_to_param" in data: + data = data["string_to_param"] + + embeds1 = data["clip_l"] # text encoder 1 + embeds2 = data["clip_g"] # text encoder 2 + + num_vectors_per_token = embeds1.size()[0] + token_string = os.path.splitext(os.path.basename(embeds_file))[0] + + token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)] + + # add new word to tokenizer, count is num_vectors_per_token + num_added_tokens1 = tokenizer1.add_tokens(token_strings) + num_added_tokens2 = tokenizer2.add_tokens(token_strings) + assert num_added_tokens1 == num_vectors_per_token and num_added_tokens2 == num_vectors_per_token, ( + f"tokenizer has same word to token string (filename): {embeds_file}" + + f" / 指定した名前(ファイル名)のトークンが既に存在します: {embeds_file}" + ) + + token_ids1 = tokenizer1.convert_tokens_to_ids(token_strings) + token_ids2 = tokenizer2.convert_tokens_to_ids(token_strings) + print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids1} and {token_ids2}") + assert ( + min(token_ids1) == token_ids1[0] and token_ids1[-1] == token_ids1[0] + len(token_ids1) - 1 + ), f"token ids1 is not ordered" + assert ( + min(token_ids2) == token_ids2[0] and token_ids2[-1] == token_ids2[0] + len(token_ids2) - 1 + ), f"token ids2 is not ordered" + assert len(tokenizer1) - 1 == token_ids1[-1], f"token ids 1 is not end of tokenize: {len(tokenizer1)}" + assert len(tokenizer2) - 1 == token_ids2[-1], f"token ids 2 is not end of tokenize: {len(tokenizer2)}" + + if num_vectors_per_token > 1: + pipe.add_token_replacement(0, token_ids1[0], token_ids1) # hoge -> hoge, hogea, hogeb, ... 
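                # Illustration (token ids are hypothetical): a 3-vector embedding loaded from
                # "hoge.safetensors" adds token_strings == ["hoge", "hoge1", "hoge2"], e.g. with
                # token_ids1 == [49408, 49409, 49410]. The replacement registered here lets
                # get_token_replacer() expand the single id of "hoge" in a prompt into all three
                # ids, so every learned vector participates in the conditioning.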
+ pipe.add_token_replacement(1, token_ids2[0], token_ids2) + + token_ids_embeds1.append((token_ids1, embeds1)) + token_ids_embeds2.append((token_ids2, embeds2)) + + text_encoder1.resize_token_embeddings(len(tokenizer1)) + text_encoder2.resize_token_embeddings(len(tokenizer2)) + token_embeds1 = text_encoder1.get_input_embeddings().weight.data + token_embeds2 = text_encoder2.get_input_embeddings().weight.data + for token_ids, embeds in token_ids_embeds1: + for token_id, embed in zip(token_ids, embeds): + token_embeds1[token_id] = embed + for token_ids, embeds in token_ids_embeds2: + for token_id, embed in zip(token_ids, embeds): + token_embeds2[token_id] = embed + + # promptを取得する + if args.from_file is not None: + print(f"reading prompts from {args.from_file}") + with open(args.from_file, "r", encoding="utf-8") as f: + prompt_list = f.read().splitlines() + prompt_list = [d for d in prompt_list if len(d.strip()) > 0] + elif args.prompt is not None: + prompt_list = [args.prompt] + else: + prompt_list = [] + + if args.interactive: + args.n_iter = 1 + + # img2imgの前処理、画像の読み込みなど + def load_images(path): + if os.path.isfile(path): + paths = [path] + else: + paths = ( + glob.glob(os.path.join(path, "*.png")) + + glob.glob(os.path.join(path, "*.jpg")) + + glob.glob(os.path.join(path, "*.jpeg")) + + glob.glob(os.path.join(path, "*.webp")) + ) + paths.sort() + + images = [] + for p in paths: + image = Image.open(p) + if image.mode != "RGB": + print(f"convert image to RGB from {image.mode}: {p}") + image = image.convert("RGB") + images.append(image) + + return images + + def resize_images(imgs, size): + resized = [] + for img in imgs: + r_img = img.resize(size, Image.Resampling.LANCZOS) + if hasattr(img, "filename"): # filename属性がない場合があるらしい + r_img.filename = img.filename + resized.append(r_img) + return resized + + if args.image_path is not None: + print(f"load image for img2img: {args.image_path}") + init_images = load_images(args.image_path) + assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}" + print(f"loaded {len(init_images)} images for img2img") + + # CLIP Vision + if args.clip_vision_strength is not None: + print(f"load CLIP Vision model: {CLIP_VISION_MODEL}") + vision_model = CLIPVisionModelWithProjection.from_pretrained(CLIP_VISION_MODEL, projection_dim=1280) + vision_model.to(device, dtype) + processor = CLIPImageProcessor.from_pretrained(CLIP_VISION_MODEL) + + pipe.clip_vision_model = vision_model + pipe.clip_vision_processor = processor + pipe.clip_vision_strength = args.clip_vision_strength + print(f"CLIP Vision model loaded.") + + else: + init_images = None + + if args.mask_path is not None: + print(f"load mask for inpainting: {args.mask_path}") + mask_images = load_images(args.mask_path) + assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.image_path}" + print(f"loaded {len(mask_images)} mask images for inpainting") + else: + mask_images = None + + # promptがないとき、画像のPngInfoから取得する + if init_images is not None and len(prompt_list) == 0 and not args.interactive: + print("get prompts from images' metadata") + for img in init_images: + if "prompt" in img.text: + prompt = img.text["prompt"] + if "negative-prompt" in img.text: + prompt += " --n " + img.text["negative-prompt"] + prompt_list.append(prompt) + + # プロンプトと画像を一致させるため指定回数だけ繰り返す(画像を増幅する) + l = [] + for im in init_images: + l.extend([im] * args.images_per_prompt) + init_images = l + + if mask_images is not None: + l = [] + for im in mask_images: + l.extend([im] * args.images_per_prompt) + mask_images = l 
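        # Illustration: the amplification above aligns source images with the repeated prompts.
        # With init_images == [imgA, imgB] and images_per_prompt == 3, the list becomes
        # [imgA, imgA, imgA, imgB, imgB, imgB]; mask_images is expanded the same way.
        #
        # A hypothetical invocation (script name and exact flag spellings are assumptions that
        # mirror the args.* attributes used in this file):
        #   python sdxl_gen_img.py --ckpt sd_xl_base_1.0.safetensors --outdir outputs \
        #       --xformers --fp16 --W 1024 --H 1024 --sampler euler_a --seed 1234 \
        #       --images_per_prompt 2 --prompt "1girl, masterpiece --n lowres, worst quality" \
        #       --network_module networks.lora --network_weights my_lora.safetensors --network_mul 0.8
        #   ("--n" inside the prompt carries the negative prompt, matching the " --n " convention
        #   used when prompts are taken from image metadata above.)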
+ + # 画像サイズにオプション指定があるときはリサイズする + if args.W is not None and args.H is not None: + # highres fix を考慮に入れる + w, h = args.W, args.H + if highres_fix: + w = int(w * args.highres_fix_scale + 0.5) + h = int(h * args.highres_fix_scale + 0.5) + + if init_images is not None: + print(f"resize img2img source images to {w}*{h}") + init_images = resize_images(init_images, (w, h)) + if mask_images is not None: + print(f"resize img2img mask images to {w}*{h}") + mask_images = resize_images(mask_images, (w, h)) + + regional_network = False + if networks and mask_images: + # mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応 + regional_network = True + print("use mask as region") + + size = None + for i, network in enumerate(networks): + if i < 3: + np_mask = np.array(mask_images[0]) + np_mask = np_mask[:, :, i] + size = np_mask.shape + else: + np_mask = np.full(size, 255, dtype=np.uint8) + mask = torch.from_numpy(np_mask.astype(np.float32) / 255.0) + network.set_region(i, i == len(networks) - 1, mask) + mask_images = None + + prev_image = None # for VGG16 guided + if args.guide_image_path is not None: + print(f"load image for ControlNet guidance: {args.guide_image_path}") + guide_images = [] + for p in args.guide_image_path: + guide_images.extend(load_images(p)) + + print(f"loaded {len(guide_images)} guide images for guidance") + if len(guide_images) == 0: + print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}") + guide_images = None + else: + guide_images = None + + # seed指定時はseedを決めておく + if args.seed is not None: + # dynamic promptを使うと足りなくなる→images_per_promptを適当に大きくしておいてもらう + random.seed(args.seed) + predefined_seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.n_iter * len(prompt_list) * args.images_per_prompt)] + if len(predefined_seeds) == 1: + predefined_seeds[0] = args.seed + else: + predefined_seeds = None + + # デフォルト画像サイズを設定する:img2imgではこれらの値は無視される(またはW*Hにリサイズ済み) + if args.W is None: + args.W = 1024 + if args.H is None: + args.H = 1024 + + # 画像生成のループ + os.makedirs(args.outdir, exist_ok=True) + max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples + + for gen_iter in range(args.n_iter): + print(f"iteration {gen_iter+1}/{args.n_iter}") + iter_seed = random.randint(0, 0x7FFFFFFF) + + # バッチ処理の関数 + def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): + batch_size = len(batch) + + # highres_fixの処理 + if highres_fix and not highres_1st: + # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す + is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling + + print("process 1st stage") + batch_1st = [] + for _, base, ext in batch: + + def scale_and_round(x): + if x is None: + return None + return int(x * args.highres_fix_scale + 0.5) + + width_1st = scale_and_round(ext.width) + height_1st = scale_and_round(ext.height) + width_1st = width_1st - width_1st % 32 + height_1st = height_1st - height_1st % 32 + + original_width_1st = scale_and_round(ext.original_width) + original_height_1st = scale_and_round(ext.original_height) + original_width_negative_1st = scale_and_round(ext.original_width_negative) + original_height_negative_1st = scale_and_round(ext.original_height_negative) + crop_left_1st = scale_and_round(ext.crop_left) + crop_top_1st = scale_and_round(ext.crop_top) + + strength_1st = ext.strength if args.highres_fix_strength is None else args.highres_fix_strength + + ext_1st = BatchDataExt( + width_1st, + height_1st, + original_width_1st, + original_height_1st, + 
original_width_negative_1st, + original_height_negative_1st, + crop_left_1st, + crop_top_1st, + args.highres_fix_steps, + ext.scale, + ext.negative_scale, + strength_1st, + ext.network_muls, + ext.num_sub_prompts, + ) + batch_1st.append(BatchData(is_1st_latent, base, ext_1st)) + + pipe.set_enable_control_net(True) # 1st stageではControlNetを有効にする + images_1st = process_batch(batch_1st, True, True) + + # 2nd stageのバッチを作成して以下処理する + print("process 2nd stage") + width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height + + if upscaler: + # upscalerを使って画像を拡大する + lowreso_imgs = None if is_1st_latent else images_1st + lowreso_latents = None if not is_1st_latent else images_1st + + # 戻り値はPIL.Image.Imageかtorch.Tensorのlatents + batch_size = len(images_1st) + vae_batch_size = ( + batch_size + if args.vae_batch_size is None + else (max(1, int(batch_size * args.vae_batch_size)) if args.vae_batch_size < 1 else args.vae_batch_size) + ) + vae_batch_size = int(vae_batch_size) + images_1st = upscaler.upscale( + vae, lowreso_imgs, lowreso_latents, dtype, width_2nd, height_2nd, batch_size, vae_batch_size + ) + + elif args.highres_fix_latents_upscaling: + # latentを拡大する + org_dtype = images_1st.dtype + if images_1st.dtype == torch.bfloat16: + images_1st = images_1st.to(torch.float) # interpolateがbf16をサポートしていない + images_1st = torch.nn.functional.interpolate( + images_1st, (batch[0].ext.height // 8, batch[0].ext.width // 8), mode="bilinear" + ) # , antialias=True) + images_1st = images_1st.to(org_dtype) + + else: + # 画像をLANCZOSで拡大する + images_1st = [image.resize((width_2nd, height_2nd), resample=PIL.Image.LANCZOS) for image in images_1st] + + batch_2nd = [] + for i, (bd, image) in enumerate(zip(batch, images_1st)): + bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed + 1, image, None, *bd.base[6:]), bd.ext) + batch_2nd.append(bd_2nd) + batch = batch_2nd + + if args.highres_fix_disable_control_net: + pipe.set_enable_control_net(False) # オプション指定時、2nd stageではControlNetを無効にする + + # このバッチの情報を取り出す + ( + return_latents, + (step_first, _, _, _, init_image, mask_image, _, guide_image), + ( + width, + height, + original_width, + original_height, + original_width_negative, + original_height_negative, + crop_left, + crop_top, + steps, + scale, + negative_scale, + strength, + network_muls, + num_sub_prompts, + ), + ) = batch[0] + noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR) + + prompts = [] + negative_prompts = [] + start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) + noises = [ + torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) + for _ in range(steps * scheduler_num_noises_per_step) + ] + seeds = [] + clip_prompts = [] + + if init_image is not None: # img2img? + i2i_noises = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) + init_images = [] + + if mask_image is not None: + mask_images = [] + else: + mask_images = None + else: + i2i_noises = None + init_images = None + mask_images = None + + if guide_image is not None: # CLIP image guided? 
+ guide_images = [] + else: + guide_images = None + + # バッチ内の位置に関わらず同じ乱数を使うためにここで乱数を生成しておく。あわせてimage/maskがbatch内で同一かチェックする + all_images_are_same = True + all_masks_are_same = True + all_guide_images_are_same = True + for i, (_, (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch): + prompts.append(prompt) + negative_prompts.append(negative_prompt) + seeds.append(seed) + clip_prompts.append(clip_prompt) + + if init_image is not None: + init_images.append(init_image) + if i > 0 and all_images_are_same: + all_images_are_same = init_images[-2] is init_image + + if mask_image is not None: + mask_images.append(mask_image) + if i > 0 and all_masks_are_same: + all_masks_are_same = mask_images[-2] is mask_image + + if guide_image is not None: + if type(guide_image) is list: + guide_images.extend(guide_image) + all_guide_images_are_same = False + else: + guide_images.append(guide_image) + if i > 0 and all_guide_images_are_same: + all_guide_images_are_same = guide_images[-2] is guide_image + + # make start code + torch.manual_seed(seed) + start_code[i] = torch.randn(noise_shape, device=device, dtype=dtype) + + # make each noises + for j in range(steps * scheduler_num_noises_per_step): + noises[j][i] = torch.randn(noise_shape, device=device, dtype=dtype) + + if i2i_noises is not None: # img2img noise + i2i_noises[i] = torch.randn(noise_shape, device=device, dtype=dtype) + + noise_manager.reset_sampler_noises(noises) + + # すべての画像が同じなら1枚だけpipeに渡すことでpipe側で処理を高速化する + if init_images is not None and all_images_are_same: + init_images = init_images[0] + if mask_images is not None and all_masks_are_same: + mask_images = mask_images[0] + if guide_images is not None and all_guide_images_are_same: + guide_images = guide_images[0] + + # ControlNet使用時はguide imageをリサイズする + if control_nets: + # TODO resampleのメソッド + guide_images = guide_images if type(guide_images) == list else [guide_images] + guide_images = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in guide_images] + if len(guide_images) == 1: + guide_images = guide_images[0] + + # generate + if networks: + # 追加ネットワークの処理 + shared = {} + for n, m in zip(networks, network_muls if network_muls else network_default_muls): + n.set_multiplier(m) + if regional_network: + n.set_current_generation(batch_size, num_sub_prompts, width, height, shared) + + if not regional_network and network_pre_calc: + for n in networks: + n.restore_weights() + for n in networks: + n.pre_calculation() + print("pre-calculation... 
done") + + images = pipe( + prompts, + negative_prompts, + init_images, + mask_images, + height, + width, + original_height, + original_width, + original_height_negative, + original_width_negative, + crop_top, + crop_left, + steps, + scale, + negative_scale, + strength, + latents=start_code, + output_type="pil", + max_embeddings_multiples=max_embeddings_multiples, + img2img_noise=i2i_noises, + vae_batch_size=args.vae_batch_size, + return_latents=return_latents, + clip_prompts=clip_prompts, + clip_guide_images=guide_images, + ) + if highres_1st and not args.highres_fix_save_1st: # return images or latents + return images + + # save image + highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + for i, (image, prompt, negative_prompts, seed, clip_prompt) in enumerate( + zip(images, prompts, negative_prompts, seeds, clip_prompts) + ): + if highres_fix: + seed -= 1 # record original seed + metadata = PngInfo() + metadata.add_text("prompt", prompt) + metadata.add_text("seed", str(seed)) + metadata.add_text("sampler", args.sampler) + metadata.add_text("steps", str(steps)) + metadata.add_text("scale", str(scale)) + if negative_prompt is not None: + metadata.add_text("negative-prompt", negative_prompt) + if negative_scale is not None: + metadata.add_text("negative-scale", str(negative_scale)) + if clip_prompt is not None: + metadata.add_text("clip-prompt", clip_prompt) + metadata.add_text("original-height", str(original_height)) + metadata.add_text("original-width", str(original_width)) + metadata.add_text("original-height-negative", str(original_height_negative)) + metadata.add_text("original-width-negative", str(original_width_negative)) + metadata.add_text("crop-top", str(crop_top)) + metadata.add_text("crop-left", str(crop_left)) + + if args.use_original_file_name and init_images is not None: + if type(init_images) is list: + fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png" + else: + fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png" + elif args.sequential_file_name: + fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png" + else: + fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png" + + image.save(os.path.join(args.outdir, fln), pnginfo=metadata) + + if not args.no_preview and not highres_1st and args.interactive: + try: + import cv2 + + for prompt, image in zip(prompts, images): + cv2.imshow(prompt[:128], np.array(image)[:, :, ::-1]) # プロンプトが長いと死ぬ + cv2.waitKey() + cv2.destroyAllWindows() + except ImportError: + print("opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません") + + return images + + # 画像生成のプロンプトが一周するまでのループ + prompt_index = 0 + global_step = 0 + batch_data = [] + while args.interactive or prompt_index < len(prompt_list): + if len(prompt_list) == 0: + # interactive + valid = False + while not valid: + print("\nType prompt:") + try: + raw_prompt = input() + except EOFError: + break + + valid = len(raw_prompt.strip().split(" --")[0].strip()) > 0 + if not valid: # EOF, end app + break + else: + raw_prompt = prompt_list[prompt_index] + + # sd-dynamic-prompts like variants: + # count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration) + raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt) + + # repeat prompt + for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)): + raw_prompt = raw_prompts[pi] if 
len(raw_prompts) > 1 else raw_prompts[0] + + if pi == 0 or len(raw_prompts) > 1: + # parse prompt: if prompt is not changed, skip parsing + width = args.W + height = args.H + original_width = args.original_width + original_height = args.original_height + original_width_negative = args.original_width_negative + original_height_negative = args.original_height_negative + crop_top = args.crop_top + crop_left = args.crop_left + scale = args.scale + negative_scale = args.negative_scale + steps = args.steps + seed = None + seeds = None + strength = 0.8 if args.strength is None else args.strength + negative_prompt = "" + clip_prompt = None + network_muls = None + + prompt_args = raw_prompt.strip().split(" --") + prompt = prompt_args[0] + print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") + + for parg in prompt_args[1:]: + try: + m = re.match(r"w (\d+)", parg, re.IGNORECASE) + if m: + width = int(m.group(1)) + print(f"width: {width}") + continue + + m = re.match(r"h (\d+)", parg, re.IGNORECASE) + if m: + height = int(m.group(1)) + print(f"height: {height}") + continue + + m = re.match(r"ow (\d+)", parg, re.IGNORECASE) + if m: + original_width = int(m.group(1)) + print(f"original width: {original_width}") + continue + + m = re.match(r"oh (\d+)", parg, re.IGNORECASE) + if m: + original_height = int(m.group(1)) + print(f"original height: {original_height}") + continue + + m = re.match(r"nw (\d+)", parg, re.IGNORECASE) + if m: + original_width_negative = int(m.group(1)) + print(f"original width negative: {original_width_negative}") + continue + + m = re.match(r"nh (\d+)", parg, re.IGNORECASE) + if m: + original_height_negative = int(m.group(1)) + print(f"original height negative: {original_height_negative}") + continue + + m = re.match(r"ct (\d+)", parg, re.IGNORECASE) + if m: + crop_top = int(m.group(1)) + print(f"crop top: {crop_top}") + continue + + m = re.match(r"cl (\d+)", parg, re.IGNORECASE) + if m: + crop_left = int(m.group(1)) + print(f"crop left: {crop_left}") + continue + + m = re.match(r"s (\d+)", parg, re.IGNORECASE) + if m: # steps + steps = max(1, min(1000, int(m.group(1)))) + print(f"steps: {steps}") + continue + + m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) + if m: # seed + seeds = [int(d) for d in m.group(1).split(",")] + print(f"seeds: {seeds}") + continue + + m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) + if m: # scale + scale = float(m.group(1)) + print(f"scale: {scale}") + continue + + m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) + if m: # negative scale + if m.group(1).lower() == "none": + negative_scale = None + else: + negative_scale = float(m.group(1)) + print(f"negative scale: {negative_scale}") + continue + + m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) + if m: # strength + strength = float(m.group(1)) + print(f"strength: {strength}") + continue + + m = re.match(r"n (.+)", parg, re.IGNORECASE) + if m: # negative prompt + negative_prompt = m.group(1) + print(f"negative prompt: {negative_prompt}") + continue + + m = re.match(r"c (.+)", parg, re.IGNORECASE) + if m: # clip prompt + clip_prompt = m.group(1) + print(f"clip prompt: {clip_prompt}") + continue + + m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # network multiplies + network_muls = [float(v) for v in m.group(1).split(",")] + while len(network_muls) < len(networks): + network_muls.append(network_muls[-1]) + print(f"network mul: {network_muls}") + continue + + except ValueError as ex: + print(f"Exception in parsing / 解析エラー: {parg}") + print(ex) + + # prepare seed 
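+            # Seed resolution order (descriptive note): seeds given in the prompt with " --d"
+            # are used first; otherwise one is popped from predefined_seeds (derived from --seed);
+            # otherwise --iter_same_seed reuses the per-iteration seed; if none of these applies,
+            # a random seed is drawn below.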
+ if seeds is not None: # given in prompt + # 数が足りないなら前のをそのまま使う + if len(seeds) > 0: + seed = seeds.pop(0) + else: + if predefined_seeds is not None: + if len(predefined_seeds) > 0: + seed = predefined_seeds.pop(0) + else: + print("predefined seeds are exhausted") + seed = None + elif args.iter_same_seed: + seeds = iter_seed + else: + seed = None # 前のを消す + + if seed is None: + seed = random.randint(0, 0x7FFFFFFF) + if args.interactive: + print(f"seed: {seed}") + + # prepare init image, guide image and mask + init_image = mask_image = guide_image = None + + # 同一イメージを使うとき、本当はlatentに変換しておくと無駄がないが面倒なのでとりあえず毎回処理する + if init_images is not None: + init_image = init_images[global_step % len(init_images)] + + # img2imgの場合は、基本的に元画像のサイズで生成する。highres fixの場合はargs.W, args.Hとscaleに従いリサイズ済みなので無視する + # 32単位に丸めたやつにresizeされるので踏襲する + if not highres_fix: + width, height = init_image.size + width = width - width % 32 + height = height - height % 32 + if width != init_image.size[0] or height != init_image.size[1]: + print( + f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます" + ) + + if mask_images is not None: + mask_image = mask_images[global_step % len(mask_images)] + + if guide_images is not None: + if control_nets: # 複数件の場合あり + c = len(control_nets) + p = global_step % (len(guide_images) // c) + guide_image = guide_images[p * c : p * c + c] + else: + guide_image = guide_images[global_step % len(guide_images)] + + if regional_network: + num_sub_prompts = len(prompt.split(" AND ")) + assert ( + len(networks) <= num_sub_prompts + ), "Number of networks must be less than or equal to number of sub prompts." + else: + num_sub_prompts = None + + b1 = BatchData( + False, + BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), + BatchDataExt( + width, + height, + original_width, + original_height, + original_width_negative, + original_height_negative, + crop_left, + crop_top, + steps, + scale, + negative_scale, + strength, + tuple(network_muls) if network_muls else None, + num_sub_prompts, + ), + ) + if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要? 
+ process_batch(batch_data, highres_fix) + batch_data.clear() + + batch_data.append(b1) + if len(batch_data) == args.batch_size: + prev_image = process_batch(batch_data, highres_fix)[0] + batch_data.clear() + + global_step += 1 + + prompt_index += 1 + + if len(batch_data) > 0: + process_batch(batch_data, highres_fix) + batch_data.clear() + + print("done!") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + parser.add_argument("--prompt", type=str, default=None, help="prompt / プロンプト") + parser.add_argument( + "--from_file", type=str, default=None, help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む" + ) + parser.add_argument( + "--interactive", action="store_true", help="interactive mode (generates one image) / 対話モード(生成される画像は1枚になります)" + ) + parser.add_argument( + "--no_preview", action="store_true", help="do not show generated image in interactive mode / 対話モードで画像を表示しない" + ) + parser.add_argument( + "--image_path", type=str, default=None, help="image to inpaint or to generate from / img2imgまたはinpaintを行う元画像" + ) + parser.add_argument("--mask_path", type=str, default=None, help="mask in inpainting / inpaint時のマスク") + parser.add_argument("--strength", type=float, default=None, help="img2img strength / img2img時のstrength") + parser.add_argument("--images_per_prompt", type=int, default=1, help="number of images per prompt / プロンプトあたりの出力枚数") + parser.add_argument("--outdir", type=str, default="outputs", help="dir to write results to / 生成画像の出力先") + parser.add_argument("--sequential_file_name", action="store_true", help="sequential output file name / 生成画像のファイル名を連番にする") + parser.add_argument( + "--use_original_file_name", + action="store_true", + help="prepend original file name in img2img / img2imgで元画像のファイル名を生成画像のファイル名の先頭に付ける", + ) + # parser.add_argument("--ddim_eta", type=float, default=0.0, help="ddim eta (eta=0.0 corresponds to deterministic sampling", ) + parser.add_argument("--n_iter", type=int, default=1, help="sample this often / 繰り返し回数") + parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ") + parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅") + parser.add_argument( + "--original_height", type=int, default=None, help="original height for SDXL conditioning / SDXLの条件付けに用いるoriginal heightの値" + ) + parser.add_argument( + "--original_width", type=int, default=None, help="original width for SDXL conditioning / SDXLの条件付けに用いるoriginal widthの値" + ) + parser.add_argument( + "--original_height_negative", + type=int, + default=None, + help="original height for SDXL unconditioning / SDXLのネガティブ条件付けに用いるoriginal heightの値", + ) + parser.add_argument( + "--original_width_negative", + type=int, + default=None, + help="original width for SDXL unconditioning / SDXLのネガティブ条件付けに用いるoriginal widthの値", + ) + parser.add_argument("--crop_top", type=int, default=None, help="crop top for SDXL conditioning / SDXLの条件付けに用いるcrop topの値") + parser.add_argument("--crop_left", type=int, default=None, help="crop left for SDXL conditioning / SDXLの条件付けに用いるcrop leftの値") + parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ") + parser.add_argument( + "--vae_batch_size", + type=float, + default=None, + help="batch size for VAE, < 1.0 for ratio / VAE処理時のバッチサイズ、1未満の値の場合は通常バッチサイズの比率", + ) + parser.add_argument( + "--vae_slices", + type=int, + default=None, + help="number of slices to split image into for VAE to reduce VRAM usage, None for no splitting (default), 
slower if specified. 16 or 32 recommended / VAE処理時にVRAM使用量削減のため画像を分割するスライス数、Noneの場合は分割しない(デフォルト)、指定すると遅くなる。16か32程度を推奨", + ) + parser.add_argument("--no_half_vae", action="store_true", help="do not use fp16/bf16 precision for VAE / VAE処理時にfp16/bf16を使わない") + parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数") + parser.add_argument( + "--sampler", + type=str, + default="ddim", + choices=[ + "ddim", + "pndm", + "lms", + "euler", + "euler_a", + "heun", + "dpm_2", + "dpm_2_a", + "dpmsolver", + "dpmsolver++", + "dpmsingle", + "k_lms", + "k_euler", + "k_euler_a", + "k_dpm_2", + "k_dpm_2_a", + ], + help=f"sampler (scheduler) type / サンプラー(スケジューラ)の種類", + ) + parser.add_argument( + "--scale", + type=float, + default=7.5, + help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty)) / guidance scale", + ) + parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ") + parser.add_argument( + "--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ" + ) + parser.add_argument( + "--tokenizer_cache_dir", + type=str, + default=None, + help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)", + ) + # parser.add_argument("--replace_clip_l14_336", action='store_true', + # help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える") + parser.add_argument( + "--seed", + type=int, + default=None, + help="seed, or seed of seeds in multiple generation / 1枚生成時のseed、または複数枚生成時の乱数seedを決めるためのseed", + ) + parser.add_argument( + "--iter_same_seed", + action="store_true", + help="use same seed for all prompts in iteration if no seed specified / 乱数seedの指定がないとき繰り返し内はすべて同じseedを使う(プロンプト間の差異の比較用)", + ) + parser.add_argument("--fp16", action="store_true", help="use fp16 / fp16を指定し省メモリ化する") + parser.add_argument("--bf16", action="store_true", help="use bfloat16 / bfloat16を指定し省メモリ化する") + parser.add_argument("--xformers", action="store_true", help="use xformers / xformersを使用し高速化する") + parser.add_argument("--sdpa", action="store_true", help="use sdpa in PyTorch 2 / sdpa") + parser.add_argument( + "--diffusers_xformers", + action="store_true", + help="use xformers by diffusers (Hypernetworks doesn't work) / Diffusersでxformersを使用する(Hypernetwork利用不可)", + ) + parser.add_argument( + "--opt_channels_last", action="store_true", help="set channels last option to model / モデルにchannels lastを指定し最適化する" + ) + parser.add_argument( + "--network_module", type=str, default=None, nargs="*", help="additional network module to use / 追加ネットワークを使う時そのモジュール名" + ) + parser.add_argument( + "--network_weights", type=str, default=None, nargs="*", help="additional network weights to load / 追加ネットワークの重み" + ) + parser.add_argument("--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率") + parser.add_argument( + "--network_args", type=str, default=None, nargs="*", help="additional argmuments for network (key=value) / ネットワークへの追加の引数" + ) + parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する") + parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする") + parser.add_argument( + "--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する" + ) + 
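+    # Illustrative invocation using the options defined above (script name assumed; adjust to
+    # the actual file name, and placeholders like model.safetensors to your own paths):
+    #   python sdxl_gen_img.py --ckpt model.safetensors --prompt "a cat" --xformers --fp16 \
+    #     --network_module networks.lora --network_weights lora.safetensors --network_mul 0.8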
parser.add_argument( + "--textual_inversion_embeddings", + type=str, + default=None, + nargs="*", + help="Embeddings files of Textual Inversion / Textual Inversionのembeddings", + ) + parser.add_argument("--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う") + parser.add_argument( + "--max_embeddings_multiples", + type=int, + default=None, + help="max embeding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる", + ) + parser.add_argument( + "--guide_image_path", type=str, default=None, nargs="*", help="image to CLIP guidance / CLIP guided SDでガイドに使う画像" + ) + parser.add_argument( + "--highres_fix_scale", + type=float, + default=None, + help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする", + ) + parser.add_argument( + "--highres_fix_steps", type=int, default=28, help="1st stage steps for highres fix / highres fixの最初のステージのステップ数" + ) + parser.add_argument( + "--highres_fix_strength", + type=float, + default=None, + help="1st stage img2img strength for highres fix / highres fixの最初のステージのimg2img時のstrength、省略時はstrengthと同じ", + ) + parser.add_argument( + "--highres_fix_save_1st", action="store_true", help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する" + ) + parser.add_argument( + "--highres_fix_latents_upscaling", + action="store_true", + help="use latents upscaling for highres fix / highres fixでlatentで拡大する", + ) + parser.add_argument( + "--highres_fix_upscaler", type=str, default=None, help="upscaler module for highres fix / highres fixで使うupscalerのモジュール名" + ) + parser.add_argument( + "--highres_fix_upscaler_args", + type=str, + default=None, + help="additional argmuments for upscaler (key=value) / upscalerへの追加の引数", + ) + parser.add_argument( + "--highres_fix_disable_control_net", + action="store_true", + help="disable ControlNet for highres fix / highres fixでControlNetを使わない", + ) + + parser.add_argument( + "--negative_scale", type=float, default=None, help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する" + ) + + parser.add_argument( + "--control_net_lllite_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名" + ) + # parser.add_argument( + # "--control_net_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名" + # ) + # parser.add_argument( + # "--control_net_preps", type=str, default=None, nargs="*", help="ControlNet preprocess to use / 使用するControlNetのプリプロセス名" + # ) + parser.add_argument( + "--control_net_multipliers", type=float, default=None, nargs="*", help="ControlNet multiplier / ControlNetの適用率" + ) + parser.add_argument( + "--control_net_ratios", + type=float, + default=None, + nargs="*", + help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率", + ) + parser.add_argument( + "--clip_vision_strength", + type=float, + default=None, + help="enable CLIP Vision Conditioning for img2img with this strength / img2imgでCLIP Vision Conditioningを有効にしてこのstrengthで処理する", + ) + # # parser.add_argument( + # "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像" + # ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + main(args) diff --git a/sdxl_minimal_inference.py b/sdxl_minimal_inference.py new file mode 100644 index 000000000..ff865629e --- /dev/null +++ b/sdxl_minimal_inference.py @@ -0,0 +1,328 @@ +# 
手元で推論を行うための最低限のコード。HuggingFace/DiffusersのCLIP、schedulerとVAEを使う +# Minimal code for performing inference at local. Use HuggingFace/Diffusers CLIP, scheduler and VAE + +import argparse +import datetime +import math +import os +import random +from einops import repeat +import numpy as np +import torch +try: + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + from library.ipex import ipex_init + ipex_init() +except Exception: + pass +from tqdm import tqdm +from transformers import CLIPTokenizer +from diffusers import EulerDiscreteScheduler +from PIL import Image +import open_clip +from safetensors.torch import load_file + +from library import model_util, sdxl_model_util +import networks.lora as lora + +# scheduler: このあたりの設定はSD1/2と同じでいいらしい +# scheduler: The settings around here seem to be the same as SD1/2 +SCHEDULER_LINEAR_START = 0.00085 +SCHEDULER_LINEAR_END = 0.0120 +SCHEDULER_TIMESTEPS = 1000 +SCHEDLER_SCHEDULE = "scaled_linear" + + +# Time EmbeddingはDiffusersからのコピー +# Time Embedding is copied from Diffusers + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + device=timesteps.device + ) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, "b -> b d", d=dim) + return embedding + + +def get_timestep_embedding(x, outdim): + assert len(x.shape) == 2 + b, dims = x.shape[0], x.shape[1] + # x = rearrange(x, "b d -> (b d)") + x = torch.flatten(x) + emb = timestep_embedding(x, outdim) + # emb = rearrange(emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=outdim) + emb = torch.reshape(emb, (b, dims * outdim)) + return emb + + +if __name__ == "__main__": + # 画像生成条件を変更する場合はここを変更 / change here to change image generation conditions + + # SDXLの追加のvector embeddingへ渡す値 / Values to pass to additional vector embedding of SDXL + target_height = 1024 + target_width = 1024 + original_height = target_height + original_width = target_width + crop_top = 0 + crop_left = 0 + + steps = 50 + guidance_scale = 7 + seed = None # 1 + + DEVICE = "cuda" + DTYPE = torch.float16 # bfloat16 may work + + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt_path", type=str, required=True) + parser.add_argument("--prompt", type=str, default="A photo of a cat") + parser.add_argument("--prompt2", type=str, default=None) + parser.add_argument("--negative_prompt", type=str, default="") + parser.add_argument("--output_dir", type=str, default=".") + parser.add_argument( + "--lora_weights", + type=str, + nargs="*", + default=[], + help="LoRA weights, only supports networks.lora, each arguement is a `path;multiplier` (semi-colon separated)", + ) + parser.add_argument("--interactive", action="store_true") + args = parser.parse_args() + + if args.prompt2 is None: + args.prompt2 = args.prompt + + # HuggingFaceのmodel id + text_encoder_1_name = "openai/clip-vit-large-patch14" + text_encoder_2_name = 
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + + # checkpointを読み込む。モデル変換についてはそちらの関数を参照 + # Load checkpoint. For model conversion, see this function + + # 本体RAMが少ない場合はGPUにロードするといいかも + # If the main RAM is small, it may be better to load it on the GPU + text_model1, text_model2, vae, unet, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( + sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.ckpt_path, "cpu" + ) + + # Text Encoder 1はSDXL本体でもHuggingFaceのものを使っている + # In SDXL, Text Encoder 1 is also using HuggingFace's + + # Text Encoder 2はSDXL本体ではopen_clipを使っている + # それを使ってもいいが、SD2のDiffusers版に合わせる形で、HuggingFaceのものを使う + # 重みの変換コードはSD2とほぼ同じ + # In SDXL, Text Encoder 2 is using open_clip + # It's okay to use it, but to match the Diffusers version of SD2, use HuggingFace's + # The weight conversion code is almost the same as SD2 + + # VAEの構造はSDXLもSD1/2と同じだが、重みは異なるようだ。何より謎のscale値が違う + # fp16でNaNが出やすいようだ + # The structure of VAE is the same as SD1/2, but the weights seem to be different. Above all, the mysterious scale value is different. + # NaN seems to be more likely to occur in fp16 + + unet.to(DEVICE, dtype=DTYPE) + unet.eval() + + vae_dtype = DTYPE + if DTYPE == torch.float16: + print("use float32 for vae") + vae_dtype = torch.float32 + vae.to(DEVICE, dtype=vae_dtype) + vae.eval() + + text_model1.to(DEVICE, dtype=DTYPE) + text_model1.eval() + text_model2.to(DEVICE, dtype=DTYPE) + text_model2.eval() + + unet.set_use_memory_efficient_attention(True, False) + if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える + vae.set_use_memory_efficient_attention_xformers(True) + + # Tokenizers + tokenizer1 = CLIPTokenizer.from_pretrained(text_encoder_1_name) + tokenizer2 = lambda x: open_clip.tokenize(x, context_length=77) + + # LoRA + for weights_file in args.lora_weights: + if ";" in weights_file: + weights_file, multiplier = weights_file.split(";") + multiplier = float(multiplier) + else: + multiplier = 1.0 + + lora_model, weights_sd = lora.create_network_from_weights( + multiplier, weights_file, vae, [text_model1, text_model2], unet, None, True + ) + lora_model.merge_to([text_model1, text_model2], unet, weights_sd, DTYPE, DEVICE) + + # scheduler + scheduler = EulerDiscreteScheduler( + num_train_timesteps=SCHEDULER_TIMESTEPS, + beta_start=SCHEDULER_LINEAR_START, + beta_end=SCHEDULER_LINEAR_END, + beta_schedule=SCHEDLER_SCHEDULE, + ) + + def generate_image(prompt, prompt2, negative_prompt, seed=None): + # 将来的にサイズ情報も変えられるようにする / Make it possible to change the size information in the future + # prepare embedding + with torch.no_grad(): + # vector + emb1 = get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256) + emb2 = get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256) + emb3 = get_timestep_embedding(torch.FloatTensor([target_height, target_width]).unsqueeze(0), 256) + # print("emb1", emb1.shape) + c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(DEVICE, dtype=DTYPE) + uc_vector = c_vector.clone().to(DEVICE, dtype=DTYPE) # ちょっとここ正しいかどうかわからない I'm not sure if this is right + + # crossattn + + # Text Encoderを二つ呼ぶ関数 Function to call two Text Encoders + def call_text_encoder(text, text2): + # text encoder 1 + batch_encoding = tokenizer1( + text, + truncation=True, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + tokens = batch_encoding["input_ids"].to(DEVICE) + + with torch.no_grad(): + enc_out = text_model1(tokens, output_hidden_states=True, return_dict=True) 
+ text_embedding1 = enc_out["hidden_states"][11] + # text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding) # layer normは通さないらしい + + # text encoder 2 + with torch.no_grad(): + tokens = tokenizer2(text2).to(DEVICE) + + enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True) + text_embedding2_penu = enc_out["hidden_states"][-2] + # print("hidden_states2", text_embedding2_penu.shape) + text_embedding2_pool = enc_out["text_embeds"] # do not support Textual Inversion + + # 連結して終了 concat and finish + text_embedding = torch.cat([text_embedding1, text_embedding2_penu], dim=2) + return text_embedding, text_embedding2_pool + + # cond + c_ctx, c_ctx_pool = call_text_encoder(prompt, prompt2) + # print(c_ctx.shape, c_ctx_p.shape, c_vector.shape) + c_vector = torch.cat([c_ctx_pool, c_vector], dim=1) + + # uncond + uc_ctx, uc_ctx_pool = call_text_encoder(negative_prompt, negative_prompt) + uc_vector = torch.cat([uc_ctx_pool, uc_vector], dim=1) + + text_embeddings = torch.cat([uc_ctx, c_ctx]) + vector_embeddings = torch.cat([uc_vector, c_vector]) + + # メモリ使用量を減らすにはここでText Encoderを削除するかCPUへ移動する + + if seed is not None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + # # random generator for initial noise + # generator = torch.Generator(device="cuda").manual_seed(seed) + generator = None + else: + generator = None + + # get the initial random noise unless the user supplied it + # SDXLはCPUでlatentsを作成しているので一応合わせておく、Diffusersはtarget deviceでlatentsを作成している + # SDXL creates latents in CPU, Diffusers creates latents in target device + latents_shape = (1, 4, target_height // 8, target_width // 8) + latents = torch.randn( + latents_shape, + generator=generator, + device="cpu", + dtype=torch.float32, + ).to(DEVICE, dtype=DTYPE) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * scheduler.init_noise_sigma + + # set timesteps + scheduler.set_timesteps(steps, DEVICE) + + # このへんはDiffusersからのコピペ + # Copy from Diffusers + timesteps = scheduler.timesteps.to(DEVICE) # .to(DTYPE) + num_latent_input = 2 + with torch.no_grad(): + for i, t in enumerate(tqdm(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) + latent_model_input = scheduler.scale_model_input(latent_model_input, t) + + noise_pred = unet(latent_model_input, t, text_embeddings, vector_embeddings) + + noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input) # uncond by negative prompt + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + # latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = scheduler.step(noise_pred, t, latents).prev_sample + + # latents = 1 / 0.18215 * latents + latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents + latents = latents.to(vae_dtype) + image = vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # image = self.numpy_to_pil(image) + image = (image * 255).round().astype("uint8") + image = [Image.fromarray(im) for im in image] + + # 保存して終了 save and finish + timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + for i, img in enumerate(image): + 
img.save(os.path.join(args.output_dir, f"image_{timestamp}_{i:03d}.png")) + + if not args.interactive: + generate_image(args.prompt, args.prompt2, args.negative_prompt, seed) + else: + # loop for interactive + while True: + prompt = input("prompt: ") + if prompt == "": + break + prompt2 = input("prompt2: ") + if prompt2 == "": + prompt2 = prompt + negative_prompt = input("negative prompt: ") + seed = input("seed: ") + if seed == "": + seed = None + else: + seed = int(seed) + generate_image(prompt, prompt2, negative_prompt, seed) + + print("Done!") diff --git a/sdxl_train.py b/sdxl_train.py new file mode 100644 index 000000000..6b255d679 --- /dev/null +++ b/sdxl_train.py @@ -0,0 +1,753 @@ +# training with captions + +import argparse +import gc +import math +import os +from multiprocessing import Value +from typing import List +import toml + +from tqdm import tqdm +import torch +try: + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + from library.ipex import ipex_init + ipex_init() +except Exception: + pass +from accelerate.utils import set_seed +from diffusers import DDPMScheduler +from library import sdxl_model_util + +import library.train_util as train_util +import library.config_util as config_util +import library.sdxl_train_util as sdxl_train_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import ( + apply_snr_weight, + prepare_scheduler_for_custom_training, + scale_v_prediction_loss_like_noise_prediction, + add_v_prediction_like_loss, +) +from library.sdxl_original_unet import SdxlUNet2DConditionModel + + +UNET_NUM_BLOCKS_FOR_BLOCK_LR = 23 + + +def get_block_params_to_optimize(unet: SdxlUNet2DConditionModel, block_lrs: List[float]) -> List[dict]: + block_params = [[] for _ in range(len(block_lrs))] + + for i, (name, param) in enumerate(unet.named_parameters()): + if name.startswith("time_embed.") or name.startswith("label_emb."): + block_index = 0 # 0 + elif name.startswith("input_blocks."): # 1-9 + block_index = 1 + int(name.split(".")[1]) + elif name.startswith("middle_block."): # 10-12 + block_index = 10 + int(name.split(".")[1]) + elif name.startswith("output_blocks."): # 13-21 + block_index = 13 + int(name.split(".")[1]) + elif name.startswith("out."): # 22 + block_index = 22 + else: + raise ValueError(f"unexpected parameter name: {name}") + + block_params[block_index].append(param) + + params_to_optimize = [] + for i, params in enumerate(block_params): + if block_lrs[i] == 0: # 0のときは学習しない do not optimize when lr is 0 + continue + params_to_optimize.append({"params": params, "lr": block_lrs[i]}) + + return params_to_optimize + + +def append_block_lr_to_logs(block_lrs, logs, lr_scheduler, optimizer_type): + lrs = lr_scheduler.get_last_lr() + + lr_index = 0 + block_index = 0 + while lr_index < len(lrs): + if block_index < UNET_NUM_BLOCKS_FOR_BLOCK_LR: + name = f"block{block_index}" + if block_lrs[block_index] == 0: + block_index += 1 + continue + elif block_index == UNET_NUM_BLOCKS_FOR_BLOCK_LR: + name = "text_encoder1" + elif block_index == UNET_NUM_BLOCKS_FOR_BLOCK_LR + 1: + name = "text_encoder2" + else: + raise ValueError(f"unexpected block_index: {block_index}") + + block_index += 1 + + logs["lr/" + name] = float(lrs[lr_index]) + + if optimizer_type.lower().startswith("DAdapt".lower()) or optimizer_type.lower() == "Prodigy".lower(): + logs["lr/d*lr/" + name] = ( + 
lr_scheduler.optimizers[-1].param_groups[lr_index]["d"] * lr_scheduler.optimizers[-1].param_groups[lr_index]["lr"] + ) + + lr_index += 1 + + +def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + sdxl_train_util.verify_sdxl_training_args(args) + + assert not args.weighted_captions, "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" + assert ( + not args.train_text_encoder or not args.cache_text_encoder_outputs + ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" + + if args.block_lr: + block_lrs = [float(lr) for lr in args.block_lr.split(",")] + assert ( + len(block_lrs) == UNET_NUM_BLOCKS_FOR_BLOCK_LR + ), f"block_lr must have {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / block_lrは{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値を指定してください" + else: + block_lrs = None + + cache_latents = args.cache_latents + use_dreambooth_method = args.in_json is None + + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する + + tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) + + # データセットを準備する + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) + if args.dataset_config is not None: + print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + print( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + if use_dreambooth_method: + print("Using DreamBooth method.") + user_config = { + "datasets": [ + { + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( + args.train_data_dir, args.reg_data_dir + ) + } + ] + } + else: + print("Training with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2]) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args, [tokenizer1, tokenizer2]) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collater = train_util.collater_class(current_epoch, current_step, ds_for_collater) + + train_dataset_group.verify_bucket_reso_steps(32) + + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group, True) + return + if len(train_dataset_group) == 0: + print( + "No data found. Please verify the metadata file and train_data_dir option. 
/ 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # acceleratorを準備する + print("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + vae_dtype = torch.float32 if args.no_half_vae else weight_dtype + + # モデルを読み込む + ( + load_stable_diffusion_format, + text_encoder1, + text_encoder2, + vae, + unet, + logit_scale, + ckpt_info, + ) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype) + # logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype) + + # verify load/save model formats + if load_stable_diffusion_format: + src_stable_diffusion_ckpt = args.pretrained_model_name_or_path + src_diffusers_model_path = None + else: + src_stable_diffusion_ckpt = None + src_diffusers_model_path = args.pretrained_model_name_or_path + + if args.save_model_as is None: + save_stable_diffusion_format = load_stable_diffusion_format + use_safetensors = args.use_safetensors + else: + save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors" + use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower()) + # assert save_stable_diffusion_format, "save_model_as must be ckpt or safetensors / save_model_asはckptかsafetensorsである必要があります" + + # Diffusers版のxformers使用フラグを設定する関数 + def set_diffusers_xformers_flag(model, valid): + def fn_recursive_set_mem_eff(module: torch.nn.Module): + if hasattr(module, "set_use_memory_efficient_attention_xformers"): + module.set_use_memory_efficient_attention_xformers(valid) + + for child in module.children(): + fn_recursive_set_mem_eff(child) + + fn_recursive_set_mem_eff(model) + + # モデルに xformers とか memory efficient attention を組み込む + if args.diffusers_xformers: + # もうU-Netを独自にしたので動かないけどVAEのxformersは動くはず + accelerator.print("Use xformers by Diffusers") + # set_diffusers_xformers_flag(unet, True) + set_diffusers_xformers_flag(vae, True) + else: + # Windows版のxformersはfloatで学習できなかったりするのでxformersを使わない設定も可能にしておく必要がある + accelerator.print("Disable Diffusers' xformers") + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) + if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える + vae.set_use_memory_efficient_attention_xformers(args.xformers) + + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=vae_dtype) + vae.requires_grad_(False) + vae.eval() + with torch.no_grad(): + train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + vae.to("cpu") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + accelerator.wait_for_everyone() + + # 学習を準備する:モデルを適切な状態にする + training_models = [] + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + training_models.append(unet) + + if args.train_text_encoder: + # TODO each option for two 
text encoders? + accelerator.print("enable text encoder training") + if args.gradient_checkpointing: + text_encoder1.gradient_checkpointing_enable() + text_encoder2.gradient_checkpointing_enable() + training_models.append(text_encoder1) + training_models.append(text_encoder2) + # set require_grad=True later + else: + text_encoder1.requires_grad_(False) + text_encoder2.requires_grad_(False) + text_encoder1.eval() + text_encoder2.eval() + + # TextEncoderの出力をキャッシュする + if args.cache_text_encoder_outputs: + # Text Encodes are eval and no grad + with torch.no_grad(): + train_dataset_group.cache_text_encoder_outputs( + (tokenizer1, tokenizer2), + (text_encoder1, text_encoder2), + accelerator.device, + None, + args.cache_text_encoder_outputs_to_disk, + accelerator.is_main_process, + ) + accelerator.wait_for_everyone() + + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=vae_dtype) + + for m in training_models: + m.requires_grad_(True) + + if block_lrs is None: + params = [] + for m in training_models: + params.extend(m.parameters()) + params_to_optimize = params + + # calculate number of trainable parameters + n_params = 0 + for p in params: + n_params += p.numel() + else: + params_to_optimize = get_block_params_to_optimize(training_models[0], block_lrs) # U-Net + for m in training_models[1:]: # Text Encoders if exists + params_to_optimize.append({"params": m.parameters(), "lr": args.learning_rate}) + + # calculate number of trainable parameters + n_params = 0 + for params in params_to_optimize: + for p in params["params"]: + n_params += p.numel() + + accelerator.print(f"number of models: {len(training_models)}") + accelerator.print(f"number of trainable parameters: {n_params}") + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) + + # dataloaderを準備する + # DataLoaderのプロセス数:0はメインプロセスになる + n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collater, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print(f"override steps. 
steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + unet.to(weight_dtype) + text_encoder1.to(weight_dtype) + text_encoder2.to(weight_dtype) + elif args.full_bf16: + assert ( + args.mixed_precision == "bf16" + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + accelerator.print("enable full bf16 training.") + unet.to(weight_dtype) + text_encoder1.to(weight_dtype) + text_encoder2.to(weight_dtype) + + # acceleratorがなんかよろしくやってくれるらしい + if args.train_text_encoder: + unet, text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler + ) + + # transform DDP after prepare + text_encoder1, text_encoder2, unet = train_util.transform_models_if_DDP([text_encoder1, text_encoder2, unet]) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) + (unet,) = train_util.transform_models_if_DDP([unet]) + text_encoder1.to(weight_dtype) + text_encoder2.to(weight_dtype) + + # TextEncoderの出力をキャッシュするときにはCPUへ移動する + if args.cache_text_encoder_outputs: + # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 + text_encoder1.to("cpu", dtype=torch.float32) + text_encoder2.to("cpu", dtype=torch.float32) + if torch.cuda.is_available(): + torch.cuda.empty_cache() + else: + # make sure Text Encoders are on GPU + text_encoder1.to(accelerator.device) + text_encoder2.to(accelerator.device) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + accelerator.print("running training / 学習開始") + accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") + # accelerator.print( + # f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" + # ) + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = 
tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + ) + prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) + if args.zero_terminal_snr: + custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) + + if accelerator.is_main_process: + init_kwargs = {} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) + + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + for m in training_models: + m.train() + + loss_total = 0 + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) + else: + with torch.no_grad(): + # latentに変換 + latents = vae.encode(batch["images"].to(vae_dtype)).latent_dist.sample().to(weight_dtype) + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents) + latents = latents * sdxl_model_util.VAE_SCALE_FACTOR + + if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: + input_ids1 = batch["input_ids"] + input_ids2 = batch["input_ids2"] + with torch.set_grad_enabled(args.train_text_encoder): + # Get the text embedding for conditioning + # TODO support weighted captions + # if args.weighted_captions: + # encoder_hidden_states = get_weighted_text_embeddings( + # tokenizer, + # text_encoder, + # batch["captions"], + # accelerator.device, + # args.max_token_length // 75 if args.max_token_length else 1, + # clip_skip=args.clip_skip, + # ) + # else: + input_ids1 = input_ids1.to(accelerator.device) + input_ids2 = input_ids2.to(accelerator.device) + encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl( + args.max_token_length, + input_ids1, + input_ids2, + tokenizer1, + tokenizer2, + text_encoder1, + text_encoder2, + None if not args.full_fp16 else weight_dtype, + ) + else: + encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype) + encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype) + pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype) + + # # verify that the text encoder outputs are correct + # ehs1, ehs2, p2 = train_util.get_hidden_states_sdxl( + # args.max_token_length, + # batch["input_ids"].to(text_encoder1.device), + # batch["input_ids2"].to(text_encoder1.device), + # tokenizer1, + # tokenizer2, + # text_encoder1, + # text_encoder2, + # None if not args.full_fp16 else weight_dtype, + # ) + # b_size = encoder_hidden_states1.shape[0] + # assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 + # assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 + # assert 
((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 + # print("text encoder outputs verified") + + # get size embeddings + orig_size = batch["original_sizes_hw"] + crop_size = batch["crop_top_lefts"] + target_size = batch["target_sizes_hw"] + embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) + + # concat embeddings + vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) + text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) + + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + + noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype + + # Predict the noise residual + with accelerator.autocast(): + noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) + + target = noise + + if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.v_pred_like_loss: + # do not mean over batch dimension for snr weight or scale v-pred loss + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.v_pred_like_loss: + loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + + loss = loss.mean() # mean over batch dimension + else: + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean") + + accelerator.backward(loss) + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = [] + for m in training_models: + params_to_clip.extend(m.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + sdxl_train_util.sample_images( + accelerator, + args, + None, + global_step, + accelerator.device, + vae, + [tokenizer1, tokenizer2], + [text_encoder1, text_encoder2], + unet, + ) + + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + sdxl_train_util.save_sd_model_on_epoch_end_or_stepwise( + args, + False, + accelerator, + src_path, + save_stable_diffusion_format, + use_safetensors, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(text_encoder1), + accelerator.unwrap_model(text_encoder2), + accelerator.unwrap_model(unet), + vae, + logit_scale, + ckpt_info, + ) + + current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず + if args.logging_dir is not None: + logs = {"loss": current_loss} + if block_lrs is None: + logs["lr"] = float(lr_scheduler.get_last_lr()[0]) + if ( + args.optimizer_type.lower().startswith("DAdapt".lower()) or 
args.optimizer_type.lower() == "Prodigy".lower() + ): # tracking d*lr value + logs["lr/d*lr"] = ( + lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] + ) + else: + append_block_lr_to_logs(block_lrs, logs, lr_scheduler, args.optimizer_type) + + accelerator.log(logs, step=global_step) + + # TODO moving averageにする + loss_total += current_loss + avr_loss = loss_total / (step + 1) + logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_total / len(train_dataloader)} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + if args.save_every_n_epochs is not None: + if accelerator.is_main_process: + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + sdxl_train_util.save_sd_model_on_epoch_end_or_stepwise( + args, + True, + accelerator, + src_path, + save_stable_diffusion_format, + use_safetensors, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(text_encoder1), + accelerator.unwrap_model(text_encoder2), + accelerator.unwrap_model(unet), + vae, + logit_scale, + ckpt_info, + ) + + sdxl_train_util.sample_images( + accelerator, + args, + epoch + 1, + global_step, + accelerator.device, + vae, + [tokenizer1, tokenizer2], + [text_encoder1, text_encoder2], + unet, + ) + + is_main_process = accelerator.is_main_process + # if is_main_process: + unet = accelerator.unwrap_model(unet) + text_encoder1 = accelerator.unwrap_model(text_encoder1) + text_encoder2 = accelerator.unwrap_model(text_encoder2) + + accelerator.end_training() + + if args.save_state: # and is_main_process: + train_util.save_state_on_train_end(args, accelerator) + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + sdxl_train_util.save_sd_model_on_train_end( + args, + src_path, + save_stable_diffusion_format, + use_safetensors, + save_dtype, + epoch, + global_step, + text_encoder1, + text_encoder2, + unet, + vae, + logit_scale, + ckpt_info, + ) + print("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_training_arguments(parser, False) + train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) + sdxl_train_util.add_sdxl_training_arguments(parser) + + parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する") + parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") + parser.add_argument( + "--no_half_vae", + action="store_true", + help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", + ) + parser.add_argument( + "--block_lr", + type=str, + default=None, + help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / " + + f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値", + ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = 
parser.parse_args() + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py new file mode 100644 index 000000000..61ebfb581 --- /dev/null +++ b/sdxl_train_control_net_lllite.py @@ -0,0 +1,609 @@ +# cond_imageをU-Netのforwardで渡すバージョンのControlNet-LLLite検証用学習コード +# training code for ControlNet-LLLite with passing cond_image to U-Net's forward + +import argparse +import gc +import json +import math +import os +import random +import time +from multiprocessing import Value +from types import SimpleNamespace +import toml + +from tqdm import tqdm +import torch +try: + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + from library.ipex import ipex_init + ipex_init() +except Exception: + pass +from torch.nn.parallel import DistributedDataParallel as DDP +from accelerate.utils import set_seed +import accelerate +from diffusers import DDPMScheduler, ControlNetModel +from safetensors.torch import load_file +from library import sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util + +import library.model_util as model_util +import library.train_util as train_util +import library.config_util as config_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +import library.huggingface_util as huggingface_util +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import ( + add_v_prediction_like_loss, + apply_snr_weight, + prepare_scheduler_for_custom_training, + pyramid_noise_like, + apply_noise_offset, + scale_v_prediction_loss_like_noise_prediction, +) +import networks.control_net_lllite_for_train as control_net_lllite_for_train + + +# TODO 他のスクリプトと共通化する +def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): + logs = { + "loss/current": current_loss, + "loss/average": avr_loss, + "lr": lr_scheduler.get_last_lr()[0], + } + + if args.optimizer_type.lower().startswith("DAdapt".lower()): + logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] + + return logs + + +def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + sdxl_train_util.verify_sdxl_training_args(args) + + cache_latents = args.cache_latents + use_user_config = args.dataset_config is not None + + if args.seed is None: + args.seed = random.randint(0, 2**32) + set_seed(args.seed) + + tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) + + # データセットを準備する + blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) + if use_user_config: + print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "conditioning_data_dir"] + if any(getattr(args, attr) is not None for attr in ignored): + print( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + user_config = { + "datasets": [ + { + "subsets": config_util.generate_controlnet_subsets_config_by_subdirs( + args.train_data_dir, + args.conditioning_data_dir, + args.caption_extension, + ) + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2]) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + + current_epoch = Value("i", 0) + 
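# these shared counters are handed to the collater below so DataLoader worker processes can see the current epoch and step + 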
current_step = Value("i", 0) + ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collater = train_util.collater_class(current_epoch, current_step, ds_for_collater) + + train_dataset_group.verify_bucket_reso_steps(32) + + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group) + return + if len(train_dataset_group) == 0: + print( + "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + else: + print("WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません") + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # acceleratorを準備する + print("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + is_main_process = accelerator.is_main_process + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + vae_dtype = torch.float32 if args.no_half_vae else weight_dtype + + # モデルを読み込む + ( + load_stable_diffusion_format, + text_encoder1, + text_encoder2, + vae, + unet, + logit_scale, + ckpt_info, + ) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype) + + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=vae_dtype) + vae.requires_grad_(False) + vae.eval() + with torch.no_grad(): + train_dataset_group.cache_latents( + vae, + args.vae_batch_size, + args.cache_latents_to_disk, + accelerator.is_main_process, + ) + vae.to("cpu") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + accelerator.wait_for_everyone() + + # TextEncoderの出力をキャッシュする + if args.cache_text_encoder_outputs: + # Text Encoders are eval and no grad + with torch.no_grad(): + train_dataset_group.cache_text_encoder_outputs( + (tokenizer1, tokenizer2), + (text_encoder1, text_encoder2), + accelerator.device, + None, + args.cache_text_encoder_outputs_to_disk, + accelerator.is_main_process, + ) + accelerator.wait_for_everyone() + + # prepare ControlNet-LLLite + control_net_lllite_for_train.replace_unet_linear_and_conv2d() + + if args.network_weights is not None: + accelerator.print(f"initialize U-Net with ControlNet-LLLite") + with accelerate.init_empty_weights(): + unet_lllite = control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite() + unet_lllite.to(accelerator.device, dtype=weight_dtype) + + unet_sd = unet.state_dict() + info = unet_lllite.load_lllite_weights(args.network_weights, unet_sd) + accelerator.print(f"load ControlNet-LLLite weights from {args.network_weights}: {info}") + else: + # consumes large memory, so send to GPU before creating the LLLite model + accelerator.print("sending U-Net to GPU") + unet.to(accelerator.device, dtype=weight_dtype) + unet_sd = unet.state_dict() + + # init LLLite weights + accelerator.print(f"initialize U-Net with ControlNet-LLLite") + + if args.lowram: + with 
accelerate.init_on_device(accelerator.device): + unet_lllite = control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite() + else: + unet_lllite = control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite() + unet_lllite.to(weight_dtype) + + info = unet_lllite.load_lllite_weights(None, unet_sd) + accelerator.print(f"init U-Net with ControlNet-LLLite weights: {info}") + del unet_sd, unet + + unet: control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite = unet_lllite + del unet_lllite + + unet.apply_lllite(args.cond_emb_dim, args.network_dim, args.network_dropout) + + # モデルに xformers とか memory efficient attention を組み込む + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + + trainable_params = list(unet.prepare_params()) + print(f"trainable params count: {len(trainable_params)}") + print(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}") + + _, _, optimizer = train_util.get_optimizer(args, trainable_params) + + # dataloaderを準備する + # DataLoaderのプロセス数:0はメインプロセスになる + n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collater, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする + # if args.full_fp16: + # assert ( + # args.mixed_precision == "fp16" + # ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + # accelerator.print("enable full fp16 training.") + # unet.to(weight_dtype) + # elif args.full_bf16: + # assert ( + # args.mixed_precision == "bf16" + # ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + # accelerator.print("enable full bf16 training.") + # unet.to(weight_dtype) + + unet.to(weight_dtype) + + # acceleratorがなんかよろしくやってくれるらしい + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) + + # transform DDP after prepare (train_network here only) + unet = train_util.transform_models_if_DDP([unet])[0] + + if args.gradient_checkpointing: + unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる + else: + unet.eval() + + # TextEncoderの出力をキャッシュするときにはCPUへ移動する + if args.cache_text_encoder_outputs: + # move Text Encoders for sampling images. 
Text Encoder doesn't work on CPU with fp16 + text_encoder1.to("cpu", dtype=torch.float32) + text_encoder2.to("cpu", dtype=torch.float32) + if torch.cuda.is_available(): + torch.cuda.empty_cache() + else: + # make sure Text Encoders are on GPU + text_encoder1.to(accelerator.device) + text_encoder2.to(accelerator.device) + + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=vae_dtype) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + # TODO: find a way to handle total batch size when there are multiple datasets + accelerator.print("running training / 学習開始") + accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") + # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + ) + prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) + if args.zero_terminal_snr: + custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) + + if accelerator.is_main_process: + init_kwargs = {} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs + ) + + loss_list = [] + loss_total = 0.0 + del train_dataset_group + + # function for saving/removing + def save_model( + ckpt_name, + unwrapped_nw: control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite, + steps, + epoch_no, + force_sync_upload=False, + ): + os.makedirs(args.output_dir, exist_ok=True) + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + accelerator.print(f"\nsaving checkpoint: {ckpt_file}") + sai_metadata = train_util.get_sai_model_spec(None, args, True, True, False) + sai_metadata["modelspec.architecture"] = sai_model_spec.ARCH_SD_XL_V1_BASE + "/control-net-lllite" + + unwrapped_nw.save_lllite_weights(ckpt_file, save_dtype, sai_metadata) + if args.huggingface_repo_id is not None: + huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, 
force_sync_upload=force_sync_upload) + + def remove_model(old_ckpt_name): + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + accelerator.print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) + + # training loop + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + with accelerator.accumulate(unet): + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device) + else: + # latentに変換 + latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample() + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents) + latents = latents * sdxl_model_util.VAE_SCALE_FACTOR + + if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: + input_ids1 = batch["input_ids"] + input_ids2 = batch["input_ids2"] + with torch.no_grad(): + # Get the text embedding for conditioning + input_ids1 = input_ids1.to(accelerator.device) + input_ids2 = input_ids2.to(accelerator.device) + encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl( + args.max_token_length, + input_ids1, + input_ids2, + tokenizer1, + tokenizer2, + text_encoder1, + text_encoder2, + None if not args.full_fp16 else weight_dtype, + ) + else: + encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype) + encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype) + pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype) + + # get size embeddings + orig_size = batch["original_sizes_hw"] + crop_size = batch["crop_top_lefts"] + target_size = batch["target_sizes_hw"] + embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) + + # concat embeddings + vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) + text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) + + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + + noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype + + controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype) + + with accelerator.autocast(): + # conditioning imageをControlNetに渡す / pass conditioning image to ControlNet + # 内部でcond_embに変換される / it will be converted to cond_emb inside + + # それらの値を使いつつ、U-Netでノイズを予測する / predict noise with U-Net using those values + noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding, controlnet_image) + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * 
loss_weights + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.v_pred_like_loss: + loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + accelerator.backward(loss) + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = unet.get_trainable_params() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + # sdxl_train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) + save_model(ckpt_name, accelerator.unwrap_model(unet), global_step, epoch) + + if args.save_state: + train_util.save_and_remove_state_stepwise(args, accelerator, global_step) + + remove_step_no = train_util.get_remove_step_no(args, global_step) + if remove_step_no is not None: + remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no) + remove_model(remove_ckpt_name) + + current_loss = loss.detach().item() + if epoch == 0: + loss_list.append(current_loss) + else: + loss_total -= loss_list[step] + loss_list[step] = current_loss + loss_total += current_loss + avr_loss = loss_total / len(loss_list) + logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if args.logging_dir is not None: + logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_total / len(loss_list)} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + # 指定エポックごとにモデルを保存 + if args.save_every_n_epochs is not None: + saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs + if is_main_process and saving: + ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1) + save_model(ckpt_name, accelerator.unwrap_model(unet), global_step, epoch + 1) + + remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) + if remove_epoch_no is not None: + remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no) + remove_model(remove_ckpt_name) + + if args.save_state: + train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) + + # self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + + # end of epoch + + if is_main_process: + unet = accelerator.unwrap_model(unet) + + accelerator.end_training() + + if is_main_process and args.save_state: + train_util.save_state_on_train_end(args, accelerator) + + if is_main_process: + ckpt_name = train_util.get_last_ckpt_name(args, "." 
+ args.save_model_as) + save_model(ckpt_name, unet, global_step, num_train_epochs, force_sync_upload=True) + + print("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, False, True, True) + train_util.add_training_arguments(parser, False) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) + sdxl_train_util.add_sdxl_training_arguments(parser) + + parser.add_argument( + "--save_model_as", + type=str, + default="safetensors", + choices=[None, "ckpt", "pt", "safetensors"], + help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)", + ) + parser.add_argument("--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数") + parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み") + parser.add_argument("--network_dim", type=int, default=None, help="network dimensions (rank) / モジュールの次元数") + parser.add_argument( + "--network_dropout", + type=float, + default=None, + help="Drops neurons out of training every step (0 or None is default behavior (no dropout), 1 would drop all neurons) / 訓練時に毎ステップでニューロンをdropする(0またはNoneはdropoutなし、1は全ニューロンをdropout)", + ) + parser.add_argument( + "--conditioning_data_dir", + type=str, + default=None, + help="conditioning data directory / 条件付けデータのディレクトリ", + ) + parser.add_argument( + "--no_half_vae", + action="store_true", + help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", + ) + return parser + + +if __name__ == "__main__": + # sdxl_original_unet.USE_REENTRANT = False + + parser = setup_parser() + + args = parser.parse_args() + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py new file mode 100644 index 000000000..f8169bdbf --- /dev/null +++ b/sdxl_train_control_net_lllite_old.py @@ -0,0 +1,579 @@ +import argparse +import gc +import json +import math +import os +import random +import time +from multiprocessing import Value +from types import SimpleNamespace +import toml + +from tqdm import tqdm +import torch +try: + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + from library.ipex import ipex_init + ipex_init() +except Exception: + pass +from torch.nn.parallel import DistributedDataParallel as DDP +from accelerate.utils import set_seed +from diffusers import DDPMScheduler, ControlNetModel +from safetensors.torch import load_file +from library import sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util + +import library.model_util as model_util +import library.train_util as train_util +import library.config_util as config_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +import library.huggingface_util as huggingface_util +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import ( + add_v_prediction_like_loss, + apply_snr_weight, + prepare_scheduler_for_custom_training, + pyramid_noise_like, + apply_noise_offset, + scale_v_prediction_loss_like_noise_prediction, +) +import networks.control_net_lllite as control_net_lllite + + +# TODO 他のスクリプトと共通化する +def generate_step_logs(args: argparse.Namespace, 
current_loss, avr_loss, lr_scheduler): + logs = { + "loss/current": current_loss, + "loss/average": avr_loss, + "lr": lr_scheduler.get_last_lr()[0], + } + + if args.optimizer_type.lower().startswith("DAdapt".lower()): + logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] + + return logs + + +def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + sdxl_train_util.verify_sdxl_training_args(args) + + cache_latents = args.cache_latents + use_user_config = args.dataset_config is not None + + if args.seed is None: + args.seed = random.randint(0, 2**32) + set_seed(args.seed) + + tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) + + # データセットを準備する + blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) + if use_user_config: + print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "conditioning_data_dir"] + if any(getattr(args, attr) is not None for attr in ignored): + print( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + user_config = { + "datasets": [ + { + "subsets": config_util.generate_controlnet_subsets_config_by_subdirs( + args.train_data_dir, + args.conditioning_data_dir, + args.caption_extension, + ) + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2]) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collater = train_util.collater_class(current_epoch, current_step, ds_for_collater) + + train_dataset_group.verify_bucket_reso_steps(32) + + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group) + return + if len(train_dataset_group) == 0: + print( + "No data found. 
Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + else: + print("WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません") + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # acceleratorを準備する + print("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + is_main_process = accelerator.is_main_process + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + vae_dtype = torch.float32 if args.no_half_vae else weight_dtype + + # モデルを読み込む + ( + load_stable_diffusion_format, + text_encoder1, + text_encoder2, + vae, + unet, + logit_scale, + ckpt_info, + ) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype) + + # モデルに xformers とか memory efficient attention を組み込む + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) + + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=vae_dtype) + vae.requires_grad_(False) + vae.eval() + with torch.no_grad(): + train_dataset_group.cache_latents( + vae, + args.vae_batch_size, + args.cache_latents_to_disk, + accelerator.is_main_process, + ) + vae.to("cpu") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + accelerator.wait_for_everyone() + + # TextEncoderの出力をキャッシュする + if args.cache_text_encoder_outputs: + # Text Encodes are eval and no grad + with torch.no_grad(): + train_dataset_group.cache_text_encoder_outputs( + (tokenizer1, tokenizer2), + (text_encoder1, text_encoder2), + accelerator.device, + None, + args.cache_text_encoder_outputs_to_disk, + accelerator.is_main_process, + ) + accelerator.wait_for_everyone() + + # prepare ControlNet + network = control_net_lllite.ControlNetLLLite(unet, args.cond_emb_dim, args.network_dim, args.network_dropout) + network.apply_to() + + if args.network_weights is not None: + info = network.load_weights(args.network_weights) + accelerator.print(f"load ControlNet weights from {args.network_weights}: {info}") + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + network.enable_gradient_checkpointing() # may have no effect + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + + trainable_params = list(network.prepare_optimizer_params()) + print(f"trainable params count: {len(trainable_params)}") + print(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}") + + _, _, optimizer = train_util.get_optimizer(args, trainable_params) + + # dataloaderを準備する + # DataLoaderのプロセス数:0はメインプロセスになる + n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collater, + num_workers=n_workers, + 
persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + unet.to(weight_dtype) + network.to(weight_dtype) + elif args.full_bf16: + assert ( + args.mixed_precision == "bf16" + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + accelerator.print("enable full bf16 training.") + unet.to(weight_dtype) + network.to(weight_dtype) + + # acceleratorがなんかよろしくやってくれるらしい + unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, network, optimizer, train_dataloader, lr_scheduler + ) + network: control_net_lllite.ControlNetLLLite + + # transform DDP after prepare (train_network here only) + unet, network = train_util.transform_models_if_DDP([unet, network]) + + if args.gradient_checkpointing: + unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる + else: + unet.eval() + + network.prepare_grad_etc() + + # TextEncoderの出力をキャッシュするときにはCPUへ移動する + if args.cache_text_encoder_outputs: + # move Text Encoders for sampling images. 
Text Encoder doesn't work on CPU with fp16 + text_encoder1.to("cpu", dtype=torch.float32) + text_encoder2.to("cpu", dtype=torch.float32) + if torch.cuda.is_available(): + torch.cuda.empty_cache() + else: + # make sure Text Encoders are on GPU + text_encoder1.to(accelerator.device) + text_encoder2.to(accelerator.device) + + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=vae_dtype) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + # TODO: find a way to handle total batch size when there are multiple datasets + accelerator.print("running training / 学習開始") + accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") + # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + ) + prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) + if args.zero_terminal_snr: + custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) + + if accelerator.is_main_process: + init_kwargs = {} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs + ) + + loss_list = [] + loss_total = 0.0 + del train_dataset_group + + # function for saving/removing + def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False): + os.makedirs(args.output_dir, exist_ok=True) + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + accelerator.print(f"\nsaving checkpoint: {ckpt_file}") + sai_metadata = train_util.get_sai_model_spec(None, args, True, True, False) + sai_metadata["modelspec.architecture"] = sai_model_spec.ARCH_SD_XL_V1_BASE + "/control-net-lllite" + + unwrapped_nw.save_weights(ckpt_file, save_dtype, sai_metadata) + if args.huggingface_repo_id is not None: + huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) + + def remove_model(old_ckpt_name): + old_ckpt_file = 
os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + accelerator.print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) + + # training loop + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + network.on_epoch_start() # train() + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + with accelerator.accumulate(network): + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device) + else: + # latentに変換 + latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample() + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents) + latents = latents * sdxl_model_util.VAE_SCALE_FACTOR + + if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: + input_ids1 = batch["input_ids"] + input_ids2 = batch["input_ids2"] + with torch.no_grad(): + # Get the text embedding for conditioning + input_ids1 = input_ids1.to(accelerator.device) + input_ids2 = input_ids2.to(accelerator.device) + encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl( + args.max_token_length, + input_ids1, + input_ids2, + tokenizer1, + tokenizer2, + text_encoder1, + text_encoder2, + None if not args.full_fp16 else weight_dtype, + ) + else: + encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype) + encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype) + pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype) + + # get size embeddings + orig_size = batch["original_sizes_hw"] + crop_size = batch["crop_top_lefts"] + target_size = batch["target_sizes_hw"] + embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) + + # concat embeddings + vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) + text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) + + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + + noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype + + controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype) + + with accelerator.autocast(): + # conditioning imageをControlNetに渡す / pass conditioning image to ControlNet + # 内部でcond_embに変換される / it will be converted to cond_emb inside + network.set_cond_image(controlnet_image) + + # それらの値を使いつつ、U-Netでノイズを予測する / predict noise with U-Net using those values + noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + + if 
args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.v_pred_like_loss: + loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + accelerator.backward(loss) + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = network.get_trainable_params() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + # sdxl_train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) + save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch) + + if args.save_state: + train_util.save_and_remove_state_stepwise(args, accelerator, global_step) + + remove_step_no = train_util.get_remove_step_no(args, global_step) + if remove_step_no is not None: + remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no) + remove_model(remove_ckpt_name) + + current_loss = loss.detach().item() + if epoch == 0: + loss_list.append(current_loss) + else: + loss_total -= loss_list[step] + loss_list[step] = current_loss + loss_total += current_loss + avr_loss = loss_total / len(loss_list) + logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if args.logging_dir is not None: + logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_total / len(loss_list)} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + # 指定エポックごとにモデルを保存 + if args.save_every_n_epochs is not None: + saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs + if is_main_process and saving: + ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1) + save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch + 1) + + remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) + if remove_epoch_no is not None: + remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no) + remove_model(remove_ckpt_name) + + if args.save_state: + train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) + + # self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + + # end of epoch + + if is_main_process: + network = accelerator.unwrap_model(network) + + accelerator.end_training() + + if is_main_process and args.save_state: + train_util.save_state_on_train_end(args, accelerator) + + if is_main_process: + ckpt_name = train_util.get_last_ckpt_name(args, "." 
+ args.save_model_as) + save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True) + + print("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, False, True, True) + train_util.add_training_arguments(parser, False) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) + sdxl_train_util.add_sdxl_training_arguments(parser) + + parser.add_argument( + "--save_model_as", + type=str, + default="safetensors", + choices=[None, "ckpt", "pt", "safetensors"], + help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)", + ) + parser.add_argument("--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数") + parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み") + parser.add_argument("--network_dim", type=int, default=None, help="network dimensions (rank) / モジュールの次元数") + parser.add_argument( + "--network_dropout", + type=float, + default=None, + help="Drops neurons out of training every step (0 or None is default behavior (no dropout), 1 would drop all neurons) / 訓練時に毎ステップでニューロンをdropする(0またはNoneはdropoutなし、1は全ニューロンをdropout)", + ) + parser.add_argument( + "--conditioning_data_dir", + type=str, + default=None, + help="conditioning data directory / 条件付けデータのディレクトリ", + ) + parser.add_argument( + "--no_half_vae", + action="store_true", + help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", + ) + return parser + + +if __name__ == "__main__": + # sdxl_original_unet.USE_REENTRANT = False + + parser = setup_parser() + + args = parser.parse_args() + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/sdxl_train_network.py b/sdxl_train_network.py new file mode 100644 index 000000000..2de57c0ac --- /dev/null +++ b/sdxl_train_network.py @@ -0,0 +1,183 @@ +import argparse +import torch +try: + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + from library.ipex import ipex_init + ipex_init() +except Exception: + pass +from library import sdxl_model_util, sdxl_train_util, train_util +import train_network + + +class SdxlNetworkTrainer(train_network.NetworkTrainer): + def __init__(self): + super().__init__() + self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR + self.is_sdxl = True + + def assert_extra_args(self, args, train_dataset_group): + super().assert_extra_args(args, train_dataset_group) + sdxl_train_util.verify_sdxl_training_args(args) + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + assert ( + args.network_train_unet_only or not args.cache_text_encoder_outputs + ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません" + + train_dataset_group.verify_bucket_reso_steps(32) + + def load_target_model(self, args, weight_dtype, accelerator): + ( + load_stable_diffusion_format, + text_encoder1, + 
text_encoder2, + vae, + unet, + logit_scale, + ckpt_info, + ) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype) + + self.load_stable_diffusion_format = load_stable_diffusion_format + self.logit_scale = logit_scale + self.ckpt_info = ckpt_info + + return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet + + def load_tokenizer(self, args): + tokenizer = sdxl_train_util.load_tokenizers(args) + return tokenizer + + def is_text_encoder_outputs_cached(self, args): + return args.cache_text_encoder_outputs + + def cache_text_encoder_outputs_if_needed( + self, args, accelerator, unet, vae, tokenizers, text_encoders, dataset: train_util.DatasetGroup, weight_dtype + ): + if args.cache_text_encoder_outputs: + if not args.lowram: + # メモリ消費を減らす + print("move vae and unet to cpu to save memory") + org_vae_device = vae.device + org_unet_device = unet.device + vae.to("cpu") + unet.to("cpu") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + dataset.cache_text_encoder_outputs( + tokenizers, + text_encoders, + accelerator.device, + weight_dtype, + args.cache_text_encoder_outputs_to_disk, + accelerator.is_main_process, + ) + + text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU + text_encoders[1].to("cpu", dtype=torch.float32) + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + if not args.lowram: + print("move vae and unet back to original device") + vae.to(org_vae_device) + unet.to(org_unet_device) + else: + # Text Encoderから毎回出力を取得するので、GPUに乗せておく + text_encoders[0].to(accelerator.device) + text_encoders[1].to(accelerator.device) + + def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): + if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: + input_ids1 = batch["input_ids"] + input_ids2 = batch["input_ids2"] + with torch.enable_grad(): + # Get the text embedding for conditioning + # TODO support weighted captions + # if args.weighted_captions: + # encoder_hidden_states = get_weighted_text_embeddings( + # tokenizer, + # text_encoder, + # batch["captions"], + # accelerator.device, + # args.max_token_length // 75 if args.max_token_length else 1, + # clip_skip=args.clip_skip, + # ) + # else: + input_ids1 = input_ids1.to(accelerator.device) + input_ids2 = input_ids2.to(accelerator.device) + encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl( + args.max_token_length, + input_ids1, + input_ids2, + tokenizers[0], + tokenizers[1], + text_encoders[0], + text_encoders[1], + None if not args.full_fp16 else weight_dtype, + ) + else: + encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype) + encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype) + pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype) + + # # verify that the text encoder outputs are correct + # ehs1, ehs2, p2 = train_util.get_hidden_states_sdxl( + # args.max_token_length, + # batch["input_ids"].to(text_encoders[0].device), + # batch["input_ids2"].to(text_encoders[0].device), + # tokenizers[0], + # tokenizers[1], + # text_encoders[0], + # text_encoders[1], + # None if not args.full_fp16 else weight_dtype, + # ) + # b_size = encoder_hidden_states1.shape[0] + # assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= 
b_size * 2 + # assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 + # assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 + # print("text encoder outputs verified") + + return encoder_hidden_states1, encoder_hidden_states2, pool2 + + def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): + noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype + + # get size embeddings + orig_size = batch["original_sizes_hw"] + crop_size = batch["crop_top_lefts"] + target_size = batch["target_sizes_hw"] + embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) + + # concat embeddings + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds + vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) + text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) + + noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) + return noise_pred + + def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet): + sdxl_train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet) + + +def setup_parser() -> argparse.ArgumentParser: + parser = train_network.setup_parser() + sdxl_train_util.add_sdxl_training_arguments(parser) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + args = train_util.read_config_from_file(args, parser) + + trainer = SdxlNetworkTrainer() + trainer.train(args) diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py new file mode 100644 index 000000000..f5cca17b2 --- /dev/null +++ b/sdxl_train_textual_inversion.py @@ -0,0 +1,140 @@ +import argparse +import os + +import regex +import torch +try: + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + from library.ipex import ipex_init + ipex_init() +except Exception: + pass +import open_clip +from library import sdxl_model_util, sdxl_train_util, train_util + +import train_textual_inversion + + +class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTrainer): + def __init__(self): + super().__init__() + self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR + self.is_sdxl = True + + def assert_extra_args(self, args, train_dataset_group): + super().assert_extra_args(args, train_dataset_group) + sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False) + + train_dataset_group.verify_bucket_reso_steps(32) + + def load_target_model(self, args, weight_dtype, accelerator): + ( + load_stable_diffusion_format, + text_encoder1, + text_encoder2, + vae, + unet, + logit_scale, + ckpt_info, + ) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype) + + self.load_stable_diffusion_format = load_stable_diffusion_format + self.logit_scale = logit_scale + self.ckpt_info = ckpt_info + + return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet + + def load_tokenizer(self, args): + tokenizer = sdxl_train_util.load_tokenizers(args) + return tokenizer + + def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): + input_ids1 = batch["input_ids"] + input_ids2 = 
batch["input_ids2"] + with torch.enable_grad(): + input_ids1 = input_ids1.to(accelerator.device) + input_ids2 = input_ids2.to(accelerator.device) + encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl( + args.max_token_length, + input_ids1, + input_ids2, + tokenizers[0], + tokenizers[1], + text_encoders[0], + text_encoders[1], + None if not args.full_fp16 else weight_dtype, + ) + return encoder_hidden_states1, encoder_hidden_states2, pool2 + + def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): + noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype + + # get size embeddings + orig_size = batch["original_sizes_hw"] + crop_size = batch["crop_top_lefts"] + target_size = batch["target_sizes_hw"] + embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) + + # concat embeddings + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds + vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) + text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) + + noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) + return noise_pred + + def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement): + sdxl_train_util.sample_images( + accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement + ) + + def save_weights(self, file, updated_embs, save_dtype, metadata): + state_dict = {"clip_l": updated_embs[0], "clip_g": updated_embs[1]} + + if save_dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + data = load_file(file) + else: + data = torch.load(file, map_location="cpu") + + emb_l = data.get("clip_l", None) # ViT-L text encoder 1 + emb_g = data.get("clip_g", None) # BiG-G text encoder 2 + + assert ( + emb_l is not None or emb_g is not None + ), f"weight file does not contains weights for text encoder 1 or 2 / 重みファイルにテキストエンコーダー1または2の重みが含まれていません: {file}" + + return [emb_l, emb_g] + + +def setup_parser() -> argparse.ArgumentParser: + parser = train_textual_inversion.setup_parser() + # don't add sdxl_train_util.add_sdxl_training_arguments(parser): because it only adds text encoder caching + # sdxl_train_util.add_sdxl_training_arguments(parser) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + args = train_util.read_config_from_file(args, parser) + + trainer = SdxlTextualInversionTrainer() + trainer.train(args) diff --git a/tools/cache_latents.py b/tools/cache_latents.py new file mode 100644 index 000000000..b6991ac19 --- /dev/null +++ b/tools/cache_latents.py @@ -0,0 +1,194 @@ +# latentsのdiskへの事前キャッシュを行う / cache latents to disk + +import argparse +import math +from multiprocessing import Value +import os + +from accelerate.utils import set_seed +import torch +from tqdm import tqdm + +from library import config_util +from library import train_util +from library import 
sdxl_train_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) + + +def cache_to_disk(args: argparse.Namespace) -> None: + train_util.prepare_dataset_args(args, True) + + # check cache latents arg + assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります" + + use_dreambooth_method = args.in_json is None + + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する + + # tokenizerを準備する:datasetを動かすために必要 + if args.sdxl: + tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) + tokenizers = [tokenizer1, tokenizer2] + else: + tokenizer = train_util.load_tokenizer(args) + tokenizers = [tokenizer] + + # データセットを準備する + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) + if args.dataset_config is not None: + print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + print( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + if use_dreambooth_method: + print("Using DreamBooth method.") + user_config = { + "datasets": [ + { + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( + args.train_data_dir, args.reg_data_dir + ) + } + ] + } + else: + print("Training with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers) + + # datasetのcache_latentsを呼ばなければ、生の画像が返る + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collater = train_util.collater_class(current_epoch, current_step, ds_for_collater) + + # acceleratorを準備する + print("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, _ = train_util.prepare_dtype(args) + vae_dtype = torch.float32 if args.no_half_vae else weight_dtype + + # モデルを読み込む + print("load model") + if args.sdxl: + (_, _, _, vae, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype) + else: + _, vae, _, _ = train_util.load_target_model(args, weight_dtype, accelerator) + + if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える + vae.set_use_memory_efficient_attention_xformers(args.xformers) + vae.to(accelerator.device, dtype=vae_dtype) + vae.requires_grad_(False) + vae.eval() + + # dataloaderを準備する + train_dataset_group.set_caching_mode("latents") + + # DataLoaderのプロセス数:0はメインプロセスになる + n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collater, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # acceleratorを使ってモデルを準備する:マルチGPUで使えるようになるはず + train_dataloader = accelerator.prepare(train_dataloader) + + # データ取得のためのループ + for batch in tqdm(train_dataloader): + 
b_size = len(batch["images"]) + vae_batch_size = b_size if args.vae_batch_size is None else args.vae_batch_size + flip_aug = batch["flip_aug"] + random_crop = batch["random_crop"] + bucket_reso = batch["bucket_reso"] + + # バッチを分割して処理する + for i in range(0, b_size, vae_batch_size): + images = batch["images"][i : i + vae_batch_size] + absolute_paths = batch["absolute_paths"][i : i + vae_batch_size] + resized_sizes = batch["resized_sizes"][i : i + vae_batch_size] + + image_infos = [] + for i, (image, absolute_path, resized_size) in enumerate(zip(images, absolute_paths, resized_sizes)): + image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path) + image_info.image = image + image_info.bucket_reso = bucket_reso + image_info.resized_size = resized_size + image_info.latents_npz = os.path.splitext(absolute_path)[0] + ".npz" + + if args.skip_existing: + if train_util.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug): + print(f"Skipping {image_info.latents_npz} because it already exists.") + continue + + image_infos.append(image_info) + + if len(image_infos) > 0: + train_util.cache_batch_latents(vae, True, image_infos, flip_aug, random_crop) + + accelerator.wait_for_everyone() + accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + train_util.add_sd_models_arguments(parser) + train_util.add_training_arguments(parser, True) + train_util.add_dataset_arguments(parser, True, True, True) + config_util.add_config_arguments(parser) + parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する") + parser.add_argument( + "--no_half_vae", + action="store_true", + help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", + ) + parser.add_argument( + "--skip_existing", + action="store_true", + help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + args = train_util.read_config_from_file(args, parser) + + cache_to_disk(args) diff --git a/tools/cache_text_encoder_outputs.py b/tools/cache_text_encoder_outputs.py new file mode 100644 index 000000000..2110e7261 --- /dev/null +++ b/tools/cache_text_encoder_outputs.py @@ -0,0 +1,191 @@ +# text encoder出力のdiskへの事前キャッシュを行う / cache text encoder outputs to disk in advance + +import argparse +import math +from multiprocessing import Value +import os + +from accelerate.utils import set_seed +import torch +from tqdm import tqdm + +from library import config_util +from library import train_util +from library import sdxl_train_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) + + +def cache_to_disk(args: argparse.Namespace) -> None: + train_util.prepare_dataset_args(args, True) + + # check cache arg + assert ( + args.cache_text_encoder_outputs_to_disk + ), "cache_text_encoder_outputs_to_disk must be True / cache_text_encoder_outputs_to_diskはTrueである必要があります" + + # できるだけ準備はしておくが今のところSDXLのみしか動かない + assert ( + args.sdxl + ), "cache_text_encoder_outputs_to_disk is only available for SDXL / cache_text_encoder_outputs_to_diskはSDXLのみ利用可能です" + + use_dreambooth_method = args.in_json is None + + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する + + # 
tokenizerを準備する:datasetを動かすために必要 + if args.sdxl: + tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) + tokenizers = [tokenizer1, tokenizer2] + else: + tokenizer = train_util.load_tokenizer(args) + tokenizers = [tokenizer] + + # データセットを準備する + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) + if args.dataset_config is not None: + print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + print( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + if use_dreambooth_method: + print("Using DreamBooth method.") + user_config = { + "datasets": [ + { + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( + args.train_data_dir, args.reg_data_dir + ) + } + ] + } + else: + print("Training with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collater = train_util.collater_class(current_epoch, current_step, ds_for_collater) + + # acceleratorを準備する + print("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, _ = train_util.prepare_dtype(args) + + # モデルを読み込む + print("load model") + if args.sdxl: + (_, text_encoder1, text_encoder2, _, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype) + text_encoders = [text_encoder1, text_encoder2] + else: + text_encoder1, _, _, _ = train_util.load_target_model(args, weight_dtype, accelerator) + text_encoders = [text_encoder1] + + for text_encoder in text_encoders: + text_encoder.to(accelerator.device, dtype=weight_dtype) + text_encoder.requires_grad_(False) + text_encoder.eval() + + # dataloaderを準備する + train_dataset_group.set_caching_mode("text") + + # DataLoaderのプロセス数:0はメインプロセスになる + n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collater, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # acceleratorを使ってモデルを準備する:マルチGPUで使えるようになるはず + train_dataloader = accelerator.prepare(train_dataloader) + + # データ取得のためのループ + for batch in tqdm(train_dataloader): + absolute_paths = batch["absolute_paths"] + input_ids1_list = batch["input_ids1_list"] + input_ids2_list = batch["input_ids2_list"] + + image_infos = [] + for absolute_path, input_ids1, input_ids2 in zip(absolute_paths, input_ids1_list, input_ids2_list): + image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path) + image_info.text_encoder_outputs_npz = os.path.splitext(absolute_path)[0] + train_util.TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX + image_info + + if args.skip_existing: + if 
os.path.exists(image_info.text_encoder_outputs_npz): + print(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.") + continue + + image_info.input_ids1 = input_ids1 + image_info.input_ids2 = input_ids2 + image_infos.append(image_info) + + if len(image_infos) > 0: + b_input_ids1 = torch.stack([image_info.input_ids1 for image_info in image_infos]) + b_input_ids2 = torch.stack([image_info.input_ids2 for image_info in image_infos]) + train_util.cache_batch_text_encoder_outputs( + image_infos, tokenizers, text_encoders, args.max_token_length, True, b_input_ids1, b_input_ids2, weight_dtype + ) + + accelerator.wait_for_everyone() + accelerator.print(f"Finished caching text encoder outputs for {len(train_dataset_group)} batches.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + train_util.add_sd_models_arguments(parser) + train_util.add_training_arguments(parser, True) + train_util.add_dataset_arguments(parser, True, True, True) + config_util.add_config_arguments(parser) + sdxl_train_util.add_sdxl_training_arguments(parser) + parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する") + parser.add_argument( + "--skip_existing", + action="store_true", + help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + args = train_util.read_config_from_file(args, parser) + + cache_to_disk(args) diff --git a/tools/merge_models.py b/tools/merge_models.py new file mode 100644 index 000000000..391bfe677 --- /dev/null +++ b/tools/merge_models.py @@ -0,0 +1,168 @@ +import argparse +import os + +import torch +from safetensors import safe_open +from safetensors.torch import load_file, save_file +from tqdm import tqdm + + +def is_unet_key(key): + # VAE or TextEncoder, the last one is for SDXL + return not ("first_stage_model" in key or "cond_stage_model" in key or "conditioner." 
in key) + + +TEXT_ENCODER_KEY_REPLACEMENTS = [ + ("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."), + ("cond_stage_model.transformer.encoder.", "cond_stage_model.transformer.text_model.encoder."), + ("cond_stage_model.transformer.final_layer_norm.", "cond_stage_model.transformer.text_model.final_layer_norm."), +] + + +# support for models with different text encoder keys +def replace_text_encoder_key(key): + for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS: + if key.startswith(rep_from): + return True, rep_to + key[len(rep_from) :] + return False, key + + +def merge(args): + if args.precision == "fp16": + dtype = torch.float16 + elif args.precision == "bf16": + dtype = torch.bfloat16 + else: + dtype = torch.float + + if args.saving_precision == "fp16": + save_dtype = torch.float16 + elif args.saving_precision == "bf16": + save_dtype = torch.bfloat16 + else: + save_dtype = torch.float + + # check if all models are safetensors + for model in args.models: + if not model.endswith("safetensors"): + print(f"Model {model} is not a safetensors model") + exit() + if not os.path.isfile(model): + print(f"Model {model} does not exist") + exit() + + assert args.ratios is None or len(args.models) == len(args.ratios), "ratios must be the same length as models" + + # load and merge + ratio = 1.0 / len(args.models) # default + supplementary_key_ratios = {} # [key] = ratio, for keys not in all models, add later + + merged_sd = None + first_model_keys = set() # check missing keys in other models + for i, model in enumerate(args.models): + if args.ratios is not None: + ratio = args.ratios[i] + + if merged_sd is None: + # load first model + print(f"Loading model {model}, ratio = {ratio}...") + merged_sd = {} + with safe_open(model, framework="pt", device=args.device) as f: + for key in tqdm(f.keys()): + value = f.get_tensor(key) + _, key = replace_text_encoder_key(key) + + first_model_keys.add(key) + + if not is_unet_key(key) and args.unet_only: + supplementary_key_ratios[key] = 1.0 # use first model's value for VAE or TextEncoder + continue + + value = ratio * value.to(dtype) # first model's value * ratio + merged_sd[key] = value + + print(f"Model has {len(merged_sd)} keys " + ("(UNet only)" if args.unet_only else "")) + continue + + # load other models + print(f"Loading model {model}, ratio = {ratio}...") + + with safe_open(model, framework="pt", device=args.device) as f: + model_keys = f.keys() + for key in tqdm(model_keys): + _, new_key = replace_text_encoder_key(key) + if new_key not in merged_sd: + if args.show_skipped and new_key not in first_model_keys: + print(f"Skip: {new_key}") + continue + + value = f.get_tensor(key) + merged_sd[new_key] = merged_sd[new_key] + ratio * value.to(dtype) + + # enumerate keys not in this model + model_keys = set(model_keys) + for key in merged_sd.keys(): + if key in model_keys: + continue + print(f"Key {key} not in model {model}, use first model's value") + if key in supplementary_key_ratios: + supplementary_key_ratios[key] += ratio + else: + supplementary_key_ratios[key] = ratio + + # add supplementary keys' value (including VAE and TextEncoder) + if len(supplementary_key_ratios) > 0: + print("add first model's value") + with safe_open(args.models[0], framework="pt", device=args.device) as f: + for key in tqdm(f.keys()): + _, new_key = replace_text_encoder_key(key) + if new_key not in supplementary_key_ratios: + continue + + if is_unet_key(new_key): # not VAE or TextEncoder + print(f"Key {new_key} not in all 
models, ratio = {supplementary_key_ratios[new_key]}") + + value = f.get_tensor(key) # original key + + if new_key not in merged_sd: + merged_sd[new_key] = supplementary_key_ratios[new_key] * value.to(dtype) + else: + merged_sd[new_key] = merged_sd[new_key] + supplementary_key_ratios[new_key] * value.to(dtype) + + # save + output_file = args.output + if not output_file.endswith(".safetensors"): + output_file = output_file + ".safetensors" + + print(f"Saving to {output_file}...") + + # convert to save_dtype + for k in merged_sd.keys(): + merged_sd[k] = merged_sd[k].to(save_dtype) + + save_file(merged_sd, output_file) + + print("Done!") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Merge models") + parser.add_argument("--models", nargs="+", type=str, help="Models to merge") + parser.add_argument("--output", type=str, help="Output model") + parser.add_argument("--ratios", nargs="+", type=float, help="Ratios of models, default is equal, total = 1.0") + parser.add_argument("--unet_only", action="store_true", help="Only merge unet") + parser.add_argument("--device", type=str, default="cpu", help="Device to use, default is cpu") + parser.add_argument( + "--precision", type=str, default="float", choices=["float", "fp16", "bf16"], help="Calculation precision, default is float" + ) + parser.add_argument( + "--saving_precision", + type=str, + default="float", + choices=["float", "fp16", "bf16"], + help="Saving precision, default is float", + ) + parser.add_argument("--show_skipped", action="store_true", help="Show skipped keys (keys not in first model)") + + args = parser.parse_args() + merge(args) diff --git a/tools/original_control_net.py b/tools/original_control_net.py index 582794de7..cd47bd76a 100644 --- a/tools/original_control_net.py +++ b/tools/original_control_net.py @@ -4,176 +4,187 @@ import torch from safetensors.torch import load_file -from diffusers import UNet2DConditionModel -from diffusers.models.unet_2d_condition import UNet2DConditionOutput +from library.original_unet import UNet2DConditionModel, SampleOutput import library.model_util as model_util class ControlNetInfo(NamedTuple): - unet: Any - net: Any - prep: Any - weight: float - ratio: float + unet: Any + net: Any + prep: Any + weight: float + ratio: float class ControlNet(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - - # make control model - self.control_model = torch.nn.Module() - - dims = [320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280] - zero_convs = torch.nn.ModuleList() - for i, dim in enumerate(dims): - sub_list = torch.nn.ModuleList([torch.nn.Conv2d(dim, dim, 1)]) - zero_convs.append(sub_list) - self.control_model.add_module("zero_convs", zero_convs) - - middle_block_out = torch.nn.Conv2d(1280, 1280, 1) - self.control_model.add_module("middle_block_out", torch.nn.ModuleList([middle_block_out])) - - dims = [16, 16, 32, 32, 96, 96, 256, 320] - strides = [1, 1, 2, 1, 2, 1, 2, 1] - prev_dim = 3 - input_hint_block = torch.nn.Sequential() - for i, (dim, stride) in enumerate(zip(dims, strides)): - input_hint_block.append(torch.nn.Conv2d(prev_dim, dim, 3, stride, 1)) - if i < len(dims) - 1: - input_hint_block.append(torch.nn.SiLU()) - prev_dim = dim - self.control_model.add_module("input_hint_block", input_hint_block) + def __init__(self) -> None: + super().__init__() + + # make control model + self.control_model = torch.nn.Module() + + dims = [320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280] + zero_convs = torch.nn.ModuleList() + for i, 
dim in enumerate(dims): + sub_list = torch.nn.ModuleList([torch.nn.Conv2d(dim, dim, 1)]) + zero_convs.append(sub_list) + self.control_model.add_module("zero_convs", zero_convs) + + middle_block_out = torch.nn.Conv2d(1280, 1280, 1) + self.control_model.add_module("middle_block_out", torch.nn.ModuleList([middle_block_out])) + + dims = [16, 16, 32, 32, 96, 96, 256, 320] + strides = [1, 1, 2, 1, 2, 1, 2, 1] + prev_dim = 3 + input_hint_block = torch.nn.Sequential() + for i, (dim, stride) in enumerate(zip(dims, strides)): + input_hint_block.append(torch.nn.Conv2d(prev_dim, dim, 3, stride, 1)) + if i < len(dims) - 1: + input_hint_block.append(torch.nn.SiLU()) + prev_dim = dim + self.control_model.add_module("input_hint_block", input_hint_block) def load_control_net(v2, unet, model): - device = unet.device - - # control sdからキー変換しつつU-Netに対応する部分のみ取り出し、DiffusersのU-Netに読み込む - # state dictを読み込む - print(f"ControlNet: loading control SD model : {model}") - - if model_util.is_safetensors(model): - ctrl_sd_sd = load_file(model) - else: - ctrl_sd_sd = torch.load(model, map_location='cpu') - ctrl_sd_sd = ctrl_sd_sd.pop("state_dict", ctrl_sd_sd) - - # 重みをU-Netに読み込めるようにする。ControlNetはSD版のstate dictなので、それを読み込む - is_difference = "difference" in ctrl_sd_sd - print("ControlNet: loading difference:", is_difference) - - # ControlNetには存在しないキーがあるので、まず現在のU-NetでSD版の全keyを作っておく - # またTransfer Controlの元weightとなる - ctrl_unet_sd_sd = model_util.convert_unet_state_dict_to_sd(v2, unet.state_dict()) - - # 元のU-Netに影響しないようにコピーする。またprefixが付いていないので付ける - for key in list(ctrl_unet_sd_sd.keys()): - ctrl_unet_sd_sd["model.diffusion_model." + key] = ctrl_unet_sd_sd.pop(key).clone() - - zero_conv_sd = {} - for key in list(ctrl_sd_sd.keys()): - if key.startswith("control_"): - unet_key = "model.diffusion_" + key[len("control_"):] - if unet_key not in ctrl_unet_sd_sd: # zero conv - zero_conv_sd[key] = ctrl_sd_sd[key] - continue - if is_difference: # Transfer Control - ctrl_unet_sd_sd[unet_key] += ctrl_sd_sd[key].to(device, dtype=unet.dtype) - else: - ctrl_unet_sd_sd[unet_key] = ctrl_sd_sd[key].to(device, dtype=unet.dtype) - - unet_config = model_util.create_unet_diffusers_config(v2) - ctrl_unet_du_sd = model_util.convert_ldm_unet_checkpoint(v2, ctrl_unet_sd_sd, unet_config) # DiffUsers版ControlNetのstate dict - - # ControlNetのU-Netを作成する - ctrl_unet = UNet2DConditionModel(**unet_config) - info = ctrl_unet.load_state_dict(ctrl_unet_du_sd) - print("ControlNet: loading Control U-Net:", info) - - # U-Net以外のControlNetを作成する - # TODO support middle only - ctrl_net = ControlNet() - info = ctrl_net.load_state_dict(zero_conv_sd) - print("ControlNet: loading ControlNet:", info) - - ctrl_unet.to(unet.device, dtype=unet.dtype) - ctrl_net.to(unet.device, dtype=unet.dtype) - return ctrl_unet, ctrl_net + device = unet.device + # control sdからキー変換しつつU-Netに対応する部分のみ取り出し、DiffusersのU-Netに読み込む + # state dictを読み込む + print(f"ControlNet: loading control SD model : {model}") -def load_preprocess(prep_type: str): - if prep_type is None or prep_type.lower() == "none": - return None + if model_util.is_safetensors(model): + ctrl_sd_sd = load_file(model) + else: + ctrl_sd_sd = torch.load(model, map_location="cpu") + ctrl_sd_sd = ctrl_sd_sd.pop("state_dict", ctrl_sd_sd) + + # 重みをU-Netに読み込めるようにする。ControlNetはSD版のstate dictなので、それを読み込む + is_difference = "difference" in ctrl_sd_sd + print("ControlNet: loading difference:", is_difference) + + # ControlNetには存在しないキーがあるので、まず現在のU-NetでSD版の全keyを作っておく + # またTransfer Controlの元weightとなる + ctrl_unet_sd_sd = 
model_util.convert_unet_state_dict_to_sd(v2, unet.state_dict()) + + # 元のU-Netに影響しないようにコピーする。またprefixが付いていないので付ける + for key in list(ctrl_unet_sd_sd.keys()): + ctrl_unet_sd_sd["model.diffusion_model." + key] = ctrl_unet_sd_sd.pop(key).clone() + + zero_conv_sd = {} + for key in list(ctrl_sd_sd.keys()): + if key.startswith("control_"): + unet_key = "model.diffusion_" + key[len("control_") :] + if unet_key not in ctrl_unet_sd_sd: # zero conv + zero_conv_sd[key] = ctrl_sd_sd[key] + continue + if is_difference: # Transfer Control + ctrl_unet_sd_sd[unet_key] += ctrl_sd_sd[key].to(device, dtype=unet.dtype) + else: + ctrl_unet_sd_sd[unet_key] = ctrl_sd_sd[key].to(device, dtype=unet.dtype) + + unet_config = model_util.create_unet_diffusers_config(v2) + ctrl_unet_du_sd = model_util.convert_ldm_unet_checkpoint(v2, ctrl_unet_sd_sd, unet_config) # DiffUsers版ControlNetのstate dict + + # ControlNetのU-Netを作成する + ctrl_unet = UNet2DConditionModel(**unet_config) + info = ctrl_unet.load_state_dict(ctrl_unet_du_sd) + print("ControlNet: loading Control U-Net:", info) + + # U-Net以外のControlNetを作成する + # TODO support middle only + ctrl_net = ControlNet() + info = ctrl_net.load_state_dict(zero_conv_sd) + print("ControlNet: loading ControlNet:", info) + + ctrl_unet.to(unet.device, dtype=unet.dtype) + ctrl_net.to(unet.device, dtype=unet.dtype) + return ctrl_unet, ctrl_net - if prep_type.startswith("canny"): - args = prep_type.split("_") - th1 = int(args[1]) if len(args) >= 2 else 63 - th2 = int(args[2]) if len(args) >= 3 else 191 - def canny(img): - img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) - return cv2.Canny(img, th1, th2) - return canny +def load_preprocess(prep_type: str): + if prep_type is None or prep_type.lower() == "none": + return None - print("Unsupported prep type:", prep_type) - return None + if prep_type.startswith("canny"): + args = prep_type.split("_") + th1 = int(args[1]) if len(args) >= 2 else 63 + th2 = int(args[2]) if len(args) >= 3 else 191 + def canny(img): + img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) + return cv2.Canny(img, th1, th2) -def preprocess_ctrl_net_hint_image(image): - image = np.array(image).astype(np.float32) / 255.0 - # ControlNetのサンプルはcv2を使っているが、読み込みはGradioなので実はRGBになっている - # image = image[:, :, ::-1].copy() # rgb to bgr - image = image[None].transpose(0, 3, 1, 2) # nchw - image = torch.from_numpy(image) - return image # 0 to 1 + return canny + print("Unsupported prep type:", prep_type) + return None -def get_guided_hints(control_nets: List[ControlNetInfo], num_latent_input, b_size, hints): - guided_hints = [] - for i, cnet_info in enumerate(control_nets): - # hintは 1枚目の画像のcnet1, 1枚目の画像のcnet2, 1枚目の画像のcnet3, 2枚目の画像のcnet1, 2枚目の画像のcnet2 ... 
と並んでいること - b_hints = [] - if len(hints) == 1: # すべて同じ画像をhintとして使う - hint = hints[0] - if cnet_info.prep is not None: - hint = cnet_info.prep(hint) - hint = preprocess_ctrl_net_hint_image(hint) - b_hints = [hint for _ in range(b_size)] - else: - for bi in range(b_size): - hint = hints[(bi * len(control_nets) + i) % len(hints)] - if cnet_info.prep is not None: - hint = cnet_info.prep(hint) - hint = preprocess_ctrl_net_hint_image(hint) - b_hints.append(hint) - b_hints = torch.cat(b_hints, dim=0) - b_hints = b_hints.to(cnet_info.unet.device, dtype=cnet_info.unet.dtype) - guided_hint = cnet_info.net.control_model.input_hint_block(b_hints) - guided_hints.append(guided_hint) - return guided_hints +def preprocess_ctrl_net_hint_image(image): + image = np.array(image).astype(np.float32) / 255.0 + # ControlNetのサンプルはcv2を使っているが、読み込みはGradioなので実はRGBになっている + # image = image[:, :, ::-1].copy() # rgb to bgr + image = image[None].transpose(0, 3, 1, 2) # nchw + image = torch.from_numpy(image) + return image # 0 to 1 -def call_unet_and_control_net(step, num_latent_input, original_unet, control_nets: List[ControlNetInfo], guided_hints, current_ratio, sample, timestep, encoder_hidden_states): - # ControlNet - # 複数のControlNetの場合は、出力をマージするのではなく交互に適用する - cnet_cnt = len(control_nets) - cnet_idx = step % cnet_cnt - cnet_info = control_nets[cnet_idx] +def get_guided_hints(control_nets: List[ControlNetInfo], num_latent_input, b_size, hints): + guided_hints = [] + for i, cnet_info in enumerate(control_nets): + # hintは 1枚目の画像のcnet1, 1枚目の画像のcnet2, 1枚目の画像のcnet3, 2枚目の画像のcnet1, 2枚目の画像のcnet2 ... と並んでいること + b_hints = [] + if len(hints) == 1: # すべて同じ画像をhintとして使う + hint = hints[0] + if cnet_info.prep is not None: + hint = cnet_info.prep(hint) + hint = preprocess_ctrl_net_hint_image(hint) + b_hints = [hint for _ in range(b_size)] + else: + for bi in range(b_size): + hint = hints[(bi * len(control_nets) + i) % len(hints)] + if cnet_info.prep is not None: + hint = cnet_info.prep(hint) + hint = preprocess_ctrl_net_hint_image(hint) + b_hints.append(hint) + b_hints = torch.cat(b_hints, dim=0) + b_hints = b_hints.to(cnet_info.unet.device, dtype=cnet_info.unet.dtype) + + guided_hint = cnet_info.net.control_model.input_hint_block(b_hints) + guided_hints.append(guided_hint) + return guided_hints + + +def call_unet_and_control_net( + step, + num_latent_input, + original_unet, + control_nets: List[ControlNetInfo], + guided_hints, + current_ratio, + sample, + timestep, + encoder_hidden_states, + encoder_hidden_states_for_control_net, +): + # ControlNet + # 複数のControlNetの場合は、出力をマージするのではなく交互に適用する + cnet_cnt = len(control_nets) + cnet_idx = step % cnet_cnt + cnet_info = control_nets[cnet_idx] - # print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio) - if cnet_info.ratio < current_ratio: - return original_unet(sample, timestep, encoder_hidden_states) + # print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio) + if cnet_info.ratio < current_ratio: + return original_unet(sample, timestep, encoder_hidden_states) - guided_hint = guided_hints[cnet_idx] - guided_hint = guided_hint.repeat((num_latent_input, 1, 1, 1)) - outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states) - outs = [o * cnet_info.weight for o in outs] + guided_hint = guided_hints[cnet_idx] + guided_hint = guided_hint.repeat((num_latent_input, 1, 1, 1)) + outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states_for_control_net) + outs 
= [o * cnet_info.weight for o in outs] - # U-Net - return unet_forward(False, cnet_info.net, original_unet, None, outs, sample, timestep, encoder_hidden_states) + # U-Net + return unet_forward(False, cnet_info.net, original_unet, None, outs, sample, timestep, encoder_hidden_states) """ @@ -204,118 +215,123 @@ def call_unet_and_control_net(step, num_latent_input, original_unet, control_net """ -def unet_forward(is_control_net, control_net: ControlNet, unet: UNet2DConditionModel, guided_hint, ctrl_outs, sample, timestep, encoder_hidden_states): - # copy from UNet2DConditionModel - default_overall_up_factor = 2**unet.num_upsamplers - - forward_upsample_size = False - upsample_size = None - - if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): - print("Forward upsample size to force interpolation output size.") - forward_upsample_size = True - - # 0. center input if necessary - if unet.config.center_input_sample: - sample = 2 * sample - 1.0 - - # 1. time - timesteps = timestep - if not torch.is_tensor(timesteps): - # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can - # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - if isinstance(timestep, float): - dtype = torch.float32 if is_mps else torch.float64 - else: - dtype = torch.int32 if is_mps else torch.int64 - timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) - elif len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps.expand(sample.shape[0]) - - t_emb = unet.time_proj(timesteps) - - # timesteps does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=unet.dtype) - emb = unet.time_embedding(t_emb) - - outs = [] # output of ControlNet - zc_idx = 0 - - # 2. pre-process - sample = unet.conv_in(sample) - if is_control_net: - sample += guided_hint - outs.append(control_net.control_model.zero_convs[zc_idx][0](sample)) # , emb, encoder_hidden_states)) - zc_idx += 1 - - # 3. down - down_block_res_samples = (sample,) - for downsample_block in unet.down_blocks: - if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: - sample, res_samples = downsample_block( - hidden_states=sample, - temb=emb, - encoder_hidden_states=encoder_hidden_states, - ) - else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb) +def unet_forward( + is_control_net, + control_net: ControlNet, + unet: UNet2DConditionModel, + guided_hint, + ctrl_outs, + sample, + timestep, + encoder_hidden_states, +): + # copy from UNet2DConditionModel + default_overall_up_factor = 2**unet.num_upsamplers + + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + print("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = unet.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=unet.dtype) + emb = unet.time_embedding(t_emb) + + outs = [] # output of ControlNet + zc_idx = 0 + + # 2. pre-process + sample = unet.conv_in(sample) if is_control_net: - for rs in res_samples: - outs.append(control_net.control_model.zero_convs[zc_idx][0](rs)) # , emb, encoder_hidden_states)) + sample += guided_hint + outs.append(control_net.control_model.zero_convs[zc_idx][0](sample)) # , emb, encoder_hidden_states)) zc_idx += 1 - down_block_res_samples += res_samples - - # 4. mid - sample = unet.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states) - if is_control_net: - outs.append(control_net.control_model.middle_block_out[0](sample)) - return outs - - if not is_control_net: - sample += ctrl_outs.pop() - - # 5. up - for i, upsample_block in enumerate(unet.up_blocks): - is_final_block = i == len(unet.up_blocks) - 1 - - res_samples = down_block_res_samples[-len(upsample_block.resnets):] - down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] - - if not is_control_net and len(ctrl_outs) > 0: - res_samples = list(res_samples) - apply_ctrl_outs = ctrl_outs[-len(res_samples):] - ctrl_outs = ctrl_outs[:-len(res_samples)] - for j in range(len(res_samples)): - res_samples[j] = res_samples[j] + apply_ctrl_outs[j] - res_samples = tuple(res_samples) - - # if we have not reached the final block and need to forward the - # upsample size, we do it here - if not is_final_block and forward_upsample_size: - upsample_size = down_block_res_samples[-1].shape[2:] - - if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: - sample = upsample_block( - hidden_states=sample, - temb=emb, - res_hidden_states_tuple=res_samples, - encoder_hidden_states=encoder_hidden_states, - upsample_size=upsample_size, - ) - else: - sample = upsample_block( - hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size - ) - # 6. post-process - sample = unet.conv_norm_out(sample) - sample = unet.conv_act(sample) - sample = unet.conv_out(sample) - - return UNet2DConditionOutput(sample=sample) + # 3. down + down_block_res_samples = (sample,) + for downsample_block in unet.down_blocks: + if downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + if is_control_net: + for rs in res_samples: + outs.append(control_net.control_model.zero_convs[zc_idx][0](rs)) # , emb, encoder_hidden_states)) + zc_idx += 1 + + down_block_res_samples += res_samples + + # 4. 
mid + sample = unet.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states) + if is_control_net: + outs.append(control_net.control_model.middle_block_out[0](sample)) + return outs + + if not is_control_net: + sample += ctrl_outs.pop() + + # 5. up + for i, upsample_block in enumerate(unet.up_blocks): + is_final_block = i == len(unet.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + if not is_control_net and len(ctrl_outs) > 0: + res_samples = list(res_samples) + apply_ctrl_outs = ctrl_outs[-len(res_samples) :] + ctrl_outs = ctrl_outs[: -len(res_samples)] + for j in range(len(res_samples)): + res_samples[j] = res_samples[j] + apply_ctrl_outs[j] + res_samples = tuple(res_samples) + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + # 6. post-process + sample = unet.conv_norm_out(sample) + sample = unet.conv_act(sample) + sample = unet.conv_out(sample) + + return SampleOutput(sample=sample) diff --git a/tools/show_metadata.py b/tools/show_metadata.py new file mode 100644 index 000000000..92ca7b1c8 --- /dev/null +++ b/tools/show_metadata.py @@ -0,0 +1,19 @@ +import json +import argparse +from safetensors import safe_open + +parser = argparse.ArgumentParser() +parser.add_argument("--model", type=str, required=True) +args = parser.parse_args() + +with safe_open(args.model, framework="pt") as f: + metadata = f.metadata() + +if metadata is None: + print("No metadata found") +else: + # metadata is json dict, but not pretty printed + # sort by key and pretty print + print(json.dumps(metadata, indent=4, sort_keys=True)) + + \ No newline at end of file diff --git a/train_controlnet.py b/train_controlnet.py new file mode 100644 index 000000000..42da44125 --- /dev/null +++ b/train_controlnet.py @@ -0,0 +1,611 @@ +import argparse +import gc +import json +import math +import os +import random +import time +from multiprocessing import Value +from types import SimpleNamespace +import toml + +from tqdm import tqdm +import torch +try: + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + from library.ipex import ipex_init + ipex_init() +except Exception: + pass +from torch.nn.parallel import DistributedDataParallel as DDP +from accelerate.utils import set_seed +from diffusers import DDPMScheduler, ControlNetModel +from safetensors.torch import load_file + +import library.model_util as model_util +import library.train_util as train_util +import library.config_util as config_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +import library.huggingface_util as huggingface_util +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import ( + apply_snr_weight, + pyramid_noise_like, + apply_noise_offset, +) + + +# TODO 他のスクリプトと共通化する +def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): + logs = { + "loss/current": current_loss, + 
"loss/average": avr_loss, + "lr": lr_scheduler.get_last_lr()[0], + } + + if args.optimizer_type.lower().startswith("DAdapt".lower()): + logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] + + return logs + + +def train(args): + # session_id = random.randint(0, 2**32) + # training_started_at = time.time() + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + + cache_latents = args.cache_latents + use_user_config = args.dataset_config is not None + + if args.seed is None: + args.seed = random.randint(0, 2**32) + set_seed(args.seed) + + tokenizer = train_util.load_tokenizer(args) + + # データセットを準備する + blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) + if use_user_config: + print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "conditioning_data_dir"] + if any(getattr(args, attr) is not None for attr in ignored): + print( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + user_config = { + "datasets": [ + { + "subsets": config_util.generate_controlnet_subsets_config_by_subdirs( + args.train_data_dir, + args.conditioning_data_dir, + args.caption_extension, + ) + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collater = train_util.collater_class(current_epoch, current_step, ds_for_collater) + + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group) + return + if len(train_dataset_group) == 0: + print( + "No data found. 
Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + # acceleratorを準備する + print("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + is_main_process = accelerator.is_main_process + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + + # モデルを読み込む + text_encoder, vae, unet, _ = train_util.load_target_model( + args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=True + ) + + # DiffusersのControlNetが使用するデータを準備する + if args.v2: + unet.config = { + "act_fn": "silu", + "attention_head_dim": [5, 10, 20, 20], + "block_out_channels": [320, 640, 1280, 1280], + "center_input_sample": False, + "cross_attention_dim": 1024, + "down_block_types": ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"], + "downsample_padding": 1, + "dual_cross_attention": False, + "flip_sin_to_cos": True, + "freq_shift": 0, + "in_channels": 4, + "layers_per_block": 2, + "mid_block_scale_factor": 1, + "norm_eps": 1e-05, + "norm_num_groups": 32, + "num_class_embeds": None, + "only_cross_attention": False, + "out_channels": 4, + "sample_size": 96, + "up_block_types": ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"], + "use_linear_projection": True, + "upcast_attention": True, + "only_cross_attention": False, + "downsample_padding": 1, + "use_linear_projection": True, + "class_embed_type": None, + "num_class_embeds": None, + "resnet_time_scale_shift": "default", + "projection_class_embeddings_input_dim": None, + } + else: + unet.config = { + "act_fn": "silu", + "attention_head_dim": 8, + "block_out_channels": [320, 640, 1280, 1280], + "center_input_sample": False, + "cross_attention_dim": 768, + "down_block_types": ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"], + "downsample_padding": 1, + "flip_sin_to_cos": True, + "freq_shift": 0, + "in_channels": 4, + "layers_per_block": 2, + "mid_block_scale_factor": 1, + "norm_eps": 1e-05, + "norm_num_groups": 32, + "out_channels": 4, + "sample_size": 64, + "up_block_types": ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"], + "only_cross_attention": False, + "downsample_padding": 1, + "use_linear_projection": False, + "class_embed_type": None, + "num_class_embeds": None, + "upcast_attention": False, + "resnet_time_scale_shift": "default", + "projection_class_embeddings_input_dim": None, + } + unet.config = SimpleNamespace(**unet.config) + + controlnet = ControlNetModel.from_unet(unet) + + if args.controlnet_model_name_or_path: + filename = args.controlnet_model_name_or_path + if os.path.isfile(filename): + if os.path.splitext(filename)[1] == ".safetensors": + state_dict = load_file(filename) + else: + state_dict = torch.load(filename) + state_dict = model_util.convert_controlnet_state_dict_to_diffusers(state_dict) + controlnet.load_state_dict(state_dict) + elif os.path.isdir(filename): + controlnet = ControlNetModel.from_pretrained(filename) + + # モデルに xformers とか memory efficient attention を組み込む + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) + + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, 
dtype=weight_dtype) + vae.requires_grad_(False) + vae.eval() + with torch.no_grad(): + train_dataset_group.cache_latents( + vae, + args.vae_batch_size, + args.cache_latents_to_disk, + accelerator.is_main_process, + ) + vae.to("cpu") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + accelerator.wait_for_everyone() + + if args.gradient_checkpointing: + controlnet.enable_gradient_checkpointing() + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + + trainable_params = controlnet.parameters() + + _, _, optimizer = train_util.get_optimizer(args, trainable_params) + + # dataloaderを準備する + # DataLoaderのプロセス数:0はメインプロセスになる + n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collater, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + controlnet.to(weight_dtype) + + # acceleratorがなんかよろしくやってくれるらしい + controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + controlnet, optimizer, train_dataloader, lr_scheduler + ) + + unet.requires_grad_(False) + text_encoder.requires_grad_(False) + unet.to(accelerator.device) + text_encoder.to(accelerator.device) + + # transform DDP after prepare + controlnet = controlnet.module if isinstance(controlnet, DDP) else controlnet + + controlnet.train() + + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=weight_dtype) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + # TODO: find a way to handle total batch size when there are multiple datasets + accelerator.print("running training / 学習開始") + accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for 
d in train_dataset_group.datasets])}") + # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm( + range(args.max_train_steps), + smoothing=0, + disable=not accelerator.is_local_main_process, + desc="steps", + ) + global_step = 0 + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + num_train_timesteps=1000, + clip_sample=False, + ) + if accelerator.is_main_process: + init_kwargs = {} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers("controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) + + loss_list = [] + loss_total = 0.0 + del train_dataset_group + + # function for saving/removing + def save_model(ckpt_name, model, force_sync_upload=False): + os.makedirs(args.output_dir, exist_ok=True) + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + accelerator.print(f"\nsaving checkpoint: {ckpt_file}") + + state_dict = model_util.convert_controlnet_state_dict_to_sd(model.state_dict()) + + if save_dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + + if os.path.splitext(ckpt_file)[1] == ".safetensors": + from safetensors.torch import save_file + + save_file(state_dict, ckpt_file) + else: + torch.save(state_dict, ckpt_file) + + if args.huggingface_repo_id is not None: + huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) + + def remove_model(old_ckpt_name): + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + accelerator.print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) + + # training loop + for epoch in range(num_train_epochs): + if is_main_process: + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + with accelerator.accumulate(controlnet): + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device) + else: + # latentに変換 + latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + b_size = latents.shape[0] + + input_ids = batch["input_ids"].to(accelerator.device) + encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents, device=latents.device) + if args.noise_offset: + noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale) + elif args.multires_noise_iterations: + noise = pyramid_noise_like( + noise, + latents.device, + args.multires_noise_iterations, + args.multires_noise_discount, + ) + + # Sample a random timestep for each image + timesteps = torch.randint( + 0, + noise_scheduler.config.num_train_timesteps, + (b_size,), + device=latents.device, + ) + timesteps = timesteps.long() + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion 
process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype) + + with accelerator.autocast(): + down_block_res_samples, mid_block_res_sample = controlnet( + noisy_latents, + timesteps, + encoder_hidden_states=encoder_hidden_states, + controlnet_cond=controlnet_image, + return_dict=False, + ) + + # Predict the noise residual + noise_pred = unet( + noisy_latents, + timesteps, + encoder_hidden_states, + down_block_additional_residuals=[sample.to(dtype=weight_dtype) for sample in down_block_res_samples], + mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), + ).sample + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + accelerator.backward(loss) + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = controlnet.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + train_util.sample_images( + accelerator, + args, + None, + global_step, + accelerator.device, + vae, + tokenizer, + text_encoder, + unet, + controlnet=controlnet, + ) + + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) + save_model( + ckpt_name, + accelerator.unwrap_model(controlnet), + ) + + if args.save_state: + train_util.save_and_remove_state_stepwise(args, accelerator, global_step) + + remove_step_no = train_util.get_remove_step_no(args, global_step) + if remove_step_no is not None: + remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no) + remove_model(remove_ckpt_name) + + current_loss = loss.detach().item() + if epoch == 0: + loss_list.append(current_loss) + else: + loss_total -= loss_list[step] + loss_list[step] = current_loss + loss_total += current_loss + avr_loss = loss_total / len(loss_list) + logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if args.logging_dir is not None: + logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_total / len(loss_list)} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + # 指定エポックごとにモデルを保存 + if args.save_every_n_epochs is not None: + saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs + if is_main_process and saving: + ckpt_name = train_util.get_epoch_ckpt_name(args, "." 
+ args.save_model_as, epoch + 1) + save_model(ckpt_name, accelerator.unwrap_model(controlnet)) + + remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) + if remove_epoch_no is not None: + remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no) + remove_model(remove_ckpt_name) + + if args.save_state: + train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) + + train_util.sample_images( + accelerator, + args, + epoch + 1, + global_step, + accelerator.device, + vae, + tokenizer, + text_encoder, + unet, + controlnet=controlnet, + ) + + # end of epoch + if is_main_process: + controlnet = accelerator.unwrap_model(controlnet) + + accelerator.end_training() + + if is_main_process and args.save_state: + train_util.save_state_on_train_end(args, accelerator) + + # del accelerator # この後メモリを使うのでこれは消す→printで使うので消さずにおく + + if is_main_process: + ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) + save_model(ckpt_name, controlnet, force_sync_upload=True) + + print("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, False, True, True) + train_util.add_training_arguments(parser, False) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) + + parser.add_argument( + "--save_model_as", + type=str, + default="safetensors", + choices=[None, "ckpt", "pt", "safetensors"], + help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)", + ) + parser.add_argument( + "--controlnet_model_name_or_path", + type=str, + default=None, + help="controlnet model name or path / controlnetのモデル名またはパス", + ) + parser.add_argument( + "--conditioning_data_dir", + type=str, + default=None, + help="conditioning data directory / 条件付けデータのディレクトリ", + ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/train_db.py b/train_db.py index 927e79dea..feb147787 100644 --- a/train_db.py +++ b/train_db.py @@ -2,18 +2,23 @@ # XXX dropped option: fine_tune import gc -import time import argparse import itertools import math import os -import toml from multiprocessing import Value +import toml from tqdm import tqdm import torch +try: + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + from library.ipex import ipex_init + ipex_init() +except Exception: + pass from accelerate.utils import set_seed -import diffusers from diffusers import DDPMScheduler import library.train_util as train_util @@ -48,7 +53,7 @@ def train(args): # データセットを準備する if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True)) + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, False, True)) if args.dataset_config is not None: print(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) @@ -99,7 +104,7 @@ def train(args): f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデル(U-NetおよびText Encoder)の学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です" ) - accelerator, unwrap_model = train_util.prepare_accelerator(args) + accelerator = train_util.prepare_accelerator(args) # mixed 
precisionに対応した型を用意しておき適宜castする weight_dtype, save_dtype = train_util.prepare_dtype(args) @@ -123,7 +128,7 @@ def train(args): use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower()) # モデルに xformers とか memory efficient attention を組み込む - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) # 学習を準備する if cache_latents: @@ -144,7 +149,7 @@ def train(args): unet.requires_grad_(True) # 念のため追加 text_encoder.requires_grad_(train_text_encoder) if not train_text_encoder: - print("Text Encoder is not trained.") + accelerator.print("Text Encoder is not trained.") if args.gradient_checkpointing: unet.enable_gradient_checkpointing() @@ -156,12 +161,13 @@ def train(args): vae.to(accelerator.device, dtype=weight_dtype) # 学習に必要なクラスを準備する - print("prepare optimizer, data loader etc.") + accelerator.print("prepare optimizer, data loader etc.") if train_text_encoder: - trainable_params = itertools.chain(unet.parameters(), text_encoder.parameters()) + # wightout list, adamw8bit is crashed + trainable_params = list(itertools.chain(unet.parameters(), text_encoder.parameters())) else: trainable_params = unet.parameters() - + _, _, optimizer = train_util.get_optimizer(args, trainable_params) # dataloaderを準備する @@ -181,7 +187,7 @@ def train(args): args.max_train_steps = args.max_train_epochs * math.ceil( len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) - print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) @@ -197,7 +203,7 @@ def train(args): assert ( args.mixed_precision == "fp16" ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" - print("enable full fp16 training.") + accelerator.print("enable full fp16 training.") unet.to(weight_dtype) text_encoder.to(weight_dtype) @@ -230,15 +236,17 @@ def train(args): # 学習する total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - print("running training / 学習開始") - print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") - print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") - print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") - print(f" num epochs / epoch数: {num_train_epochs}") - print(f" batch size per device / バッチサイズ: {args.train_batch_size}") - print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") - print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") - print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + accelerator.print("running training / 学習開始") + accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print(f" batch size per device / バッチサイズ: {args.train_batch_size}") + accelerator.print( + f" total train batch size (with parallel & distributed & 
accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" + ) + accelerator.print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") global_step = 0 @@ -247,14 +255,19 @@ def train(args): beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False ) prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) + if args.zero_terminal_snr: + custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) if accelerator.is_main_process: - accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name) + init_kwargs = {} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) loss_list = [] loss_total = 0.0 for epoch in range(num_train_epochs): - print(f"\nepoch {epoch+1}/{num_train_epochs}") + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 # 指定したステップ数までText Encoderを学習する:epoch最初の状態 @@ -267,7 +280,7 @@ def train(args): current_step.value = global_step # 指定したステップ数でText Encoderの学習を止める if global_step == args.stop_text_encoder_training: - print(f"stop text encoder training at step {global_step}") + accelerator.print(f"stop text encoder training at step {global_step}") if not args.gradient_checkpointing: text_encoder.train(False) text_encoder.requires_grad_(False) @@ -282,15 +295,6 @@ def train(args): latents = latents * 0.18215 b_size = latents.shape[0] - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents, device=latents.device) - if args.noise_offset: - noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale) - elif args.multires_noise_iterations: - noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount) - # elif args.perlin_noise: - # noise = perlin_noise(noise, latents.device, args.perlin_noise) # only shape of noise is used currently - # Get the text embedding for conditioning with torch.set_grad_enabled(global_step < args.stop_text_encoder_training): if args.weighted_captions: @@ -308,13 +312,9 @@ def train(args): args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype ) - # Sample a random timestep for each image - timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) - timesteps = timesteps.long() - - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) # Predict the noise residual with accelerator.autocast(): @@ -376,15 +376,17 @@ def train(args): epoch, num_train_epochs, global_step, - unwrap_model(text_encoder), - unwrap_model(unet), + accelerator.unwrap_model(text_encoder), + accelerator.unwrap_model(unet), vae, ) current_loss = 
loss.detach().item() if args.logging_dir is not None: logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} - if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): # tracking d*lr value + if ( + args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower() + ): # tracking d*lr value logs["lr/d*lr"] = ( lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] ) @@ -424,8 +426,8 @@ def train(args): epoch, num_train_epochs, global_step, - unwrap_model(text_encoder), - unwrap_model(unet), + accelerator.unwrap_model(text_encoder), + accelerator.unwrap_model(unet), vae, ) @@ -433,8 +435,8 @@ def train(args): is_main_process = accelerator.is_main_process if is_main_process: - unet = unwrap_model(unet) - text_encoder = unwrap_model(text_encoder) + unet = accelerator.unwrap_model(unet) + text_encoder = accelerator.unwrap_model(text_encoder) accelerator.end_training() diff --git a/train_network.py b/train_network.py index da0ca1c9c..1a1713259 100644 --- a/train_network.py +++ b/train_network.py @@ -3,16 +3,28 @@ import gc import math import os +import sys import random import time import json -import toml from multiprocessing import Value +import toml from tqdm import tqdm import torch + +try: + import intel_extension_for_pytorch as ipex + + if torch.xpu.is_available(): + from library.ipex import ipex_init + + ipex_init() +except Exception: + pass from accelerate.utils import set_seed from diffusers import DDPMScheduler +from library import model_util import library.train_util as train_util from library.train_util import ( @@ -29,762 +41,877 @@ apply_snr_weight, get_weighted_text_embeddings, prepare_scheduler_for_custom_training, - pyramid_noise_like, - apply_noise_offset, scale_v_prediction_loss_like_noise_prediction, + add_v_prediction_like_loss, ) -# TODO 他のスクリプトと共通化する -def generate_step_logs( - args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, keys_scaled=None, mean_norm=None, maximum_norm=None -): - logs = {"loss/current": current_loss, "loss/average": avr_loss} - - if keys_scaled is not None: - logs["max_norm/keys_scaled"] = keys_scaled - logs["max_norm/average_key_norm"] = mean_norm - logs["max_norm/max_key_norm"] = maximum_norm - - lrs = lr_scheduler.get_last_lr() - - if args.network_train_text_encoder_only or len(lrs) <= 2: # not block lr (or single block) - if args.network_train_unet_only: - logs["lr/unet"] = float(lrs[0]) - elif args.network_train_text_encoder_only: - logs["lr/textencoder"] = float(lrs[0]) - else: - logs["lr/textencoder"] = float(lrs[0]) - logs["lr/unet"] = float(lrs[-1]) # may be same to textencoder - - if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): # tracking d*lr value of unet. 
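Both training scripts in this patch replace the inline noise/timestep sampling with `train_util.get_noise_noisy_latents_and_timesteps`. A minimal sketch of what such a helper consolidates, assuming only the plain noise-offset path (multires/pyramid noise and adaptive scaling are omitted here):

```python
import torch

def get_noise_noisy_latents_and_timesteps_sketch(args, noise_scheduler, latents):
    # Sketch only: gathers noise sampling, optional noise offset, timestep sampling
    # and the forward-diffusion add_noise call into one place.
    noise = torch.randn_like(latents, device=latents.device)
    if args.noise_offset:
        # per-sample/per-channel mean shift; helps the model learn very dark or bright images
        noise = noise + args.noise_offset * torch.randn(
            (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
        )

    timesteps = torch.randint(
        0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device
    ).long()

    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
    return noise, noisy_latents, timesteps
```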
- logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] - else: - idx = 0 - if not args.network_train_unet_only: - logs["lr/textencoder"] = float(lrs[0]) - idx = 1 - - for i in range(idx, len(lrs)): - logs[f"lr/group{i}"] = float(lrs[i]) - if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): - logs[f"lr/d*lr/group{i}"] = ( - lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] - ) - - return logs - - -def train(args): - session_id = random.randint(0, 2**32) - training_started_at = time.time() - train_util.verify_training_args(args) - train_util.prepare_dataset_args(args, True) +class NetworkTrainer: + def __init__(self): + self.vae_scale_factor = 0.18215 + self.is_sdxl = False - cache_latents = args.cache_latents - use_dreambooth_method = args.in_json is None - use_user_config = args.dataset_config is not None + # TODO 他のスクリプトと共通化する + def generate_step_logs( + self, args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, keys_scaled=None, mean_norm=None, maximum_norm=None + ): + logs = {"loss/current": current_loss, "loss/average": avr_loss} - if args.seed is None: - args.seed = random.randint(0, 2**32) - set_seed(args.seed) + if keys_scaled is not None: + logs["max_norm/keys_scaled"] = keys_scaled + logs["max_norm/average_key_norm"] = mean_norm + logs["max_norm/max_key_norm"] = maximum_norm - tokenizer = train_util.load_tokenizer(args) + lrs = lr_scheduler.get_last_lr() - # データセットを準備する - if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True)) - if use_user_config: - print(f"Loading dataset config from {args.dataset_config}") - user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "reg_data_dir", "in_json"] - if any(getattr(args, attr) is not None for attr in ignored): - print( - "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( - ", ".join(ignored) - ) + if args.network_train_text_encoder_only or len(lrs) <= 2: # not block lr (or single block) + if args.network_train_unet_only: + logs["lr/unet"] = float(lrs[0]) + elif args.network_train_text_encoder_only: + logs["lr/textencoder"] = float(lrs[0]) + else: + logs["lr/textencoder"] = float(lrs[0]) + logs["lr/unet"] = float(lrs[-1]) # may be same to textencoder + + if ( + args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower() + ): # tracking d*lr value of unet. 
+ logs["lr/d*lr"] = ( + lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] ) else: - if use_dreambooth_method: - print("Using DreamBooth method.") - user_config = { - "datasets": [ - { - "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( - args.train_data_dir, args.reg_data_dir - ) - } - ] - } + idx = 0 + if not args.network_train_unet_only: + logs["lr/textencoder"] = float(lrs[0]) + idx = 1 + + for i in range(idx, len(lrs)): + logs[f"lr/group{i}"] = float(lrs[i]) + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): + logs[f"lr/d*lr/group{i}"] = ( + lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] + ) + + return logs + + def assert_extra_args(self, args, train_dataset_group): + pass + + def load_target_model(self, args, weight_dtype, accelerator): + text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) + return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet + + def load_tokenizer(self, args): + tokenizer = train_util.load_tokenizer(args) + return tokenizer + + def is_text_encoder_outputs_cached(self, args): + return False + + def cache_text_encoder_outputs_if_needed( + self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype + ): + for t_enc in text_encoders: + t_enc.to(accelerator.device) + + def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): + input_ids = batch["input_ids"].to(accelerator.device) + encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizers[0], text_encoders[0], weight_dtype) + return encoder_hidden_states + + def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): + noise_pred = unet(noisy_latents, timesteps, text_conds).sample + return noise_pred + + def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet): + train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet) + + def train(self, args): + session_id = random.randint(0, 2**32) + training_started_at = time.time() + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + + cache_latents = args.cache_latents + use_dreambooth_method = args.in_json is None + use_user_config = args.dataset_config is not None + + if args.seed is None: + args.seed = random.randint(0, 2**32) + set_seed(args.seed) + + # tokenizerは単体またはリスト、tokenizersは必ずリスト:既存のコードとの互換性のため + tokenizer = self.load_tokenizer(args) + tokenizers = tokenizer if isinstance(tokenizer, list) else [tokenizer] + + # データセットを準備する + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) + if use_user_config: + print(f"Loading dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "reg_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + print( + "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) else: - print("Training with captions.") - user_config = { - "datasets": [ - { - "subsets": [ - { - "image_dir": args.train_data_dir, - "metadata_file": args.in_json, - } - ] 
- } - ] - } + if use_dreambooth_method: + print("Using DreamBooth method.") + user_config = { + "datasets": [ + { + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( + args.train_data_dir, args.reg_data_dir + ) + } + ] + } + else: + print("Training with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + # use arbitrary dataset class + train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collater = train_util.collater_class(current_epoch, current_step, ds_for_collater) + + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group) + return + if len(train_dataset_group) == 0: + print( + "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" + ) + return - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) - else: - # use arbitrary dataset class - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) - - current_epoch = Value("i", 0) - current_step = Value("i", 0) - ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None - collater = train_util.collater_class(current_epoch, current_step, ds_for_collater) - - if args.debug_dataset: - train_util.debug_dataset(train_dataset_group) - return - if len(train_dataset_group) == 0: - print( - "No data found. 
Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" - ) - return + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - if cache_latents: - assert ( - train_dataset_group.is_latent_cacheable() - ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + self.assert_extra_args(args, train_dataset_group) - # acceleratorを準備する - print("preparing accelerator") - accelerator, unwrap_model = train_util.prepare_accelerator(args) - is_main_process = accelerator.is_main_process + # acceleratorを準備する + print("preparing accelerator") + accelerator = train_util.prepare_accelerator(args) + is_main_process = accelerator.is_main_process - # mixed precisionに対応した型を用意しておき適宜castする - weight_dtype, save_dtype = train_util.prepare_dtype(args) + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + vae_dtype = torch.float32 if args.no_half_vae else weight_dtype - # モデルを読み込む - text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) + # モデルを読み込む + model_version, text_encoder, vae, unet = self.load_target_model(args, weight_dtype, accelerator) - # モデルに xformers とか memory efficient attention を組み込む - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) + # text_encoder is List[CLIPTextModel] or CLIPTextModel + text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder] - # 差分追加学習のためにモデルを読み込む - import sys + # モデルに xformers とか memory efficient attention を組み込む + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) + if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える + vae.set_use_memory_efficient_attention_xformers(args.xformers) - sys.path.append(os.path.dirname(__file__)) - print("import network module:", args.network_module) - network_module = importlib.import_module(args.network_module) + # 差分追加学習のためにモデルを読み込む + sys.path.append(os.path.dirname(__file__)) + accelerator.print("import network module:", args.network_module) + network_module = importlib.import_module(args.network_module) - if args.base_weights is not None: - # base_weights が指定されている場合は、指定された重みを読み込みマージする - for i, weight_path in enumerate(args.base_weights): - if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i: - multiplier = 1.0 - else: - multiplier = args.base_weights_multiplier[i] + if args.base_weights is not None: + # base_weights が指定されている場合は、指定された重みを読み込みマージする + for i, weight_path in enumerate(args.base_weights): + if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i: + multiplier = 1.0 + else: + multiplier = args.base_weights_multiplier[i] - print(f"merging module: {weight_path} with multiplier {multiplier}") + accelerator.print(f"merging module: {weight_path} with multiplier {multiplier}") - module, weights_sd = network_module.create_network_from_weights( - multiplier, weight_path, vae, text_encoder, unet, for_inference=True - ) - module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu") - - print(f"all weights merged: {', '.join(args.base_weights)}") - - # 学習を準備する - if cache_latents: - vae.to(accelerator.device, dtype=weight_dtype) - 
vae.requires_grad_(False) - vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) - vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() - - accelerator.wait_for_everyone() - - # prepare network - net_kwargs = {} - if args.network_args is not None: - for net_arg in args.network_args: - key, value = net_arg.split("=") - net_kwargs[key] = value - - # if a new network is added in future, add if ~ then blocks for each network (;'∀') - if args.dim_from_weights: - network, _ = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet, **net_kwargs) - else: - # LyCORIS will work with this... - network = network_module.create_network( - 1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, neuron_dropout=args.network_dropout, **net_kwargs - ) - if network is None: - return - - if hasattr(network, "prepare_network"): - network.prepare_network(args) - if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"): - print( - "warning: scale_weight_norms is specified but the network does not support it / scale_weight_normsが指定されていますが、ネットワークが対応していません" - ) - args.scale_weight_norms = False - - train_unet = not args.network_train_text_encoder_only - train_text_encoder = not args.network_train_unet_only - network.apply_to(text_encoder, unet, train_text_encoder, train_unet) - - if args.network_weights is not None: - info = network.load_weights(args.network_weights) - print(f"loaded network weights from {args.network_weights}: {info}") - - if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() - text_encoder.gradient_checkpointing_enable() - network.enable_gradient_checkpointing() # may have no effect - - # 学習に必要なクラスを準備する - print("preparing optimizer, data loader etc.") - - # 後方互換性を確保するよ - try: - trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate) - except TypeError: - print( - "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)" + module, weights_sd = network_module.create_network_from_weights( + multiplier, weight_path, vae, text_encoder, unet, for_inference=True + ) + module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu") + + accelerator.print(f"all weights merged: {', '.join(args.base_weights)}") + + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=vae_dtype) + vae.requires_grad_(False) + vae.eval() + with torch.no_grad(): + train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + vae.to("cpu") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + accelerator.wait_for_everyone() + + # 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される + self.cache_text_encoder_outputs_if_needed( + args, accelerator, unet, vae, tokenizers, text_encoders, train_dataset_group, weight_dtype ) - trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr) - optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) + # prepare network + net_kwargs = {} + if args.network_args is not None: + for net_arg in args.network_args: + key, value = net_arg.split("=") + net_kwargs[key] = value - # dataloaderを準備する - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers 
= min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + # if a new network is added in future, add if ~ then blocks for each network (;'∀') + if args.dim_from_weights: + network, _ = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet, **net_kwargs) + else: + # LyCORIS will work with this... + network = network_module.create_network( + 1.0, + args.network_dim, + args.network_alpha, + vae, + text_encoder, + unet, + neuron_dropout=args.network_dropout, + **net_kwargs, + ) + if network is None: + return + + if hasattr(network, "prepare_network"): + network.prepare_network(args) + if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"): + print( + "warning: scale_weight_norms is specified but the network does not support it / scale_weight_normsが指定されていますが、ネットワークが対応していません" + ) + args.scale_weight_norms = False + + train_unet = not args.network_train_text_encoder_only + train_text_encoder = not args.network_train_unet_only and not self.is_text_encoder_outputs_cached(args) + network.apply_to(text_encoder, unet, train_text_encoder, train_unet) + + if args.network_weights is not None: + info = network.load_weights(args.network_weights) + accelerator.print(f"load network weights from {args.network_weights}: {info}") + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + for t_enc in text_encoders: + t_enc.gradient_checkpointing_enable() + del t_enc + network.enable_gradient_checkpointing() # may have no effect + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + + # 後方互換性を確保するよ + try: + trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate) + except TypeError: + accelerator.print( + "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)" + ) + trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr) - train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, - batch_size=1, - shuffle=True, - collate_fn=collater, - num_workers=n_workers, - persistent_workers=args.persistent_data_loader_workers, - ) + optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) - # 学習ステップ数を計算する - if args.max_train_epochs is not None: - args.max_train_steps = args.max_train_epochs * math.ceil( - len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps - ) - if is_main_process: - print(f"override steps. 
steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") - - # データセット側にも学習ステップを送信 - train_dataset_group.set_max_train_steps(args.max_train_steps) - - # lr schedulerを用意する - lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) - - # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする - if args.full_fp16: - assert ( - args.mixed_precision == "fp16" - ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" - print("enabling full fp16 training.") - network.to(weight_dtype) - - # acceleratorがなんかよろしくやってくれるらしい - if train_unet and train_text_encoder: - unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler - ) - elif train_unet: - unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, network, optimizer, train_dataloader, lr_scheduler - ) - elif train_text_encoder: - text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - text_encoder, network, optimizer, train_dataloader, lr_scheduler + # dataloaderを準備する + # DataLoaderのプロセス数:0はメインプロセスになる + n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collater, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, ) - else: - network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler) - - # transform DDP after prepare (train_network here only) - text_encoder, unet, network = train_util.transform_if_model_is_DDP(text_encoder, unet, network) - - unet.requires_grad_(False) - unet.to(accelerator.device, dtype=weight_dtype) - text_encoder.requires_grad_(False) - text_encoder.to(accelerator.device) - if args.gradient_checkpointing: # according to TI example in Diffusers, train is required - unet.train() - text_encoder.train() - - # set top parameter requires_grad = True for gradient checkpointing works - text_encoder.text_model.embeddings.requires_grad_(True) - else: - unet.eval() - text_encoder.eval() - - network.prepare_grad_etc(text_encoder, unet) - - if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する - vae.requires_grad_(False) - vae.eval() - vae.to(accelerator.device, dtype=weight_dtype) - - # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする - if args.full_fp16: - train_util.patch_accelerator_for_fp16_training(accelerator) - - # resumeする - train_util.resume_from_local_or_hf_if_specified(accelerator, args) - - # epoch数を計算する - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): - args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 - - # 学習する - # TODO: find a way to handle total batch size when there are multiple datasets - total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - - if is_main_process: - print("running training / 学習開始") - print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") - print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") - print(f" num batches per epoch / 
1epochのバッチ数: {len(train_dataloader)}") - print(f" num epochs / epoch数: {num_train_epochs}") - print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") - # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") - print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") - print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") - - # TODO refactor metadata creation and move to util - metadata = { - "ss_session_id": session_id, # random integer indicating which group of epochs the model came from - "ss_training_started_at": training_started_at, # unix timestamp - "ss_output_name": args.output_name, - "ss_learning_rate": args.learning_rate, - "ss_text_encoder_lr": args.text_encoder_lr, - "ss_unet_lr": args.unet_lr, - "ss_num_train_images": train_dataset_group.num_train_images, - "ss_num_reg_images": train_dataset_group.num_reg_images, - "ss_num_batches_per_epoch": len(train_dataloader), - "ss_num_epochs": num_train_epochs, - "ss_gradient_checkpointing": args.gradient_checkpointing, - "ss_gradient_accumulation_steps": args.gradient_accumulation_steps, - "ss_max_train_steps": args.max_train_steps, - "ss_lr_warmup_steps": args.lr_warmup_steps, - "ss_lr_scheduler": args.lr_scheduler, - "ss_network_module": args.network_module, - "ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim - "ss_network_alpha": args.network_alpha, # some networks may not have alpha - "ss_network_dropout": args.network_dropout, # some networks may not have dropout - "ss_mixed_precision": args.mixed_precision, - "ss_full_fp16": bool(args.full_fp16), - "ss_v2": bool(args.v2), - "ss_clip_skip": args.clip_skip, - "ss_max_token_length": args.max_token_length, - "ss_cache_latents": bool(args.cache_latents), - "ss_seed": args.seed, - "ss_lowram": args.lowram, - "ss_noise_offset": args.noise_offset, - "ss_multires_noise_iterations": args.multires_noise_iterations, - "ss_multires_noise_discount": args.multires_noise_discount, - "ss_adaptive_noise_scale": args.adaptive_noise_scale, - "ss_training_comment": args.training_comment, # will not be updated after training - "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(), - "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""), - "ss_max_grad_norm": args.max_grad_norm, - "ss_caption_dropout_rate": args.caption_dropout_rate, - "ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs, - "ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate, - "ss_face_crop_aug_range": args.face_crop_aug_range, - "ss_prior_loss_weight": args.prior_loss_weight, - "ss_min_snr_gamma": args.min_snr_gamma, - "ss_scale_weight_norms": args.scale_weight_norms, - } - - if use_user_config: - # save metadata of multiple datasets - # NOTE: pack "ss_datasets" value as json one time - # or should also pack nested collections as json? 
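The `ss_*` entries assembled here (and re-created inside the class further below) are stored as string metadata in the saved `.safetensors` file. A small sketch of reading them back, assuming the network was saved in safetensors format; nested values such as `ss_tag_frequency` are JSON strings:

```python
import json
from safetensors import safe_open

def read_lora_training_metadata(path: str) -> dict:
    # Sketch: the ss_* training metadata is stored in the safetensors header.
    with safe_open(path, framework="pt") as f:
        metadata = dict(f.metadata() or {})
    # nested collections were packed as JSON strings when the metadata was built
    if "ss_tag_frequency" in metadata:
        metadata["ss_tag_frequency"] = json.loads(metadata["ss_tag_frequency"])
    return metadata
```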
- datasets_metadata = [] - tag_frequency = {} # merge tag frequency for metadata editor - dataset_dirs_info = {} # merge subset dirs for metadata editor - - for dataset in train_dataset_group.datasets: - is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset) - dataset_metadata = { - "is_dreambooth": is_dreambooth_dataset, - "batch_size_per_device": dataset.batch_size, - "num_train_images": dataset.num_train_images, # includes repeating - "num_reg_images": dataset.num_reg_images, - "resolution": (dataset.width, dataset.height), - "enable_bucket": bool(dataset.enable_bucket), - "min_bucket_reso": dataset.min_bucket_reso, - "max_bucket_reso": dataset.max_bucket_reso, - "tag_frequency": dataset.tag_frequency, - "bucket_info": dataset.bucket_info, - } - - subsets_metadata = [] - for subset in dataset.subsets: - subset_metadata = { - "img_count": subset.img_count, - "num_repeats": subset.num_repeats, - "color_aug": bool(subset.color_aug), - "flip_aug": bool(subset.flip_aug), - "random_crop": bool(subset.random_crop), - "shuffle_caption": bool(subset.shuffle_caption), - "keep_tokens": subset.keep_tokens, - } - image_dir_or_metadata_file = None - if subset.image_dir: - image_dir = os.path.basename(subset.image_dir) - subset_metadata["image_dir"] = image_dir - image_dir_or_metadata_file = image_dir - - if is_dreambooth_dataset: - subset_metadata["class_tokens"] = subset.class_tokens - subset_metadata["is_reg"] = subset.is_reg - if subset.is_reg: - image_dir_or_metadata_file = None # not merging reg dataset - else: - metadata_file = os.path.basename(subset.metadata_file) - subset_metadata["metadata_file"] = metadata_file - image_dir_or_metadata_file = metadata_file # may overwrite - - subsets_metadata.append(subset_metadata) - - # merge dataset dir: not reg subset only - # TODO update additional-network extension to show detailed dataset config from metadata - if image_dir_or_metadata_file is not None: - # datasets may have a certain dir multiple times - v = image_dir_or_metadata_file - i = 2 - while v in dataset_dirs_info: - v = image_dir_or_metadata_file + f" ({i})" - i += 1 - image_dir_or_metadata_file = v - - dataset_dirs_info[image_dir_or_metadata_file] = {"n_repeats": subset.num_repeats, "img_count": subset.img_count} - - dataset_metadata["subsets"] = subsets_metadata - datasets_metadata.append(dataset_metadata) - - # merge tag frequency: - for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items(): - # あるディレクトリが複数のdatasetで使用されている場合、一度だけ数える - # もともと繰り返し回数を指定しているので、キャプション内でのタグの出現回数と、それが学習で何度使われるかは一致しない - # なので、ここで複数datasetの回数を合算してもあまり意味はない - if ds_dir_name in tag_frequency: - continue - tag_frequency[ds_dir_name] = ds_freq_for_dir - - metadata["ss_datasets"] = json.dumps(datasets_metadata) - metadata["ss_tag_frequency"] = json.dumps(tag_frequency) - metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info) - else: - # conserving backward compatibility when using train_dataset_dir and reg_dataset_dir - assert ( - len(train_dataset_group.datasets) == 1 - ), f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. 
/ データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。" - - dataset = train_dataset_group.datasets[0] - - dataset_dirs_info = {} - reg_dataset_dirs_info = {} - if use_dreambooth_method: - for subset in dataset.subsets: - info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info - info[os.path.basename(subset.image_dir)] = {"n_repeats": subset.num_repeats, "img_count": subset.img_count} + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + network.to(weight_dtype) + elif args.full_bf16: + assert ( + args.mixed_precision == "bf16" + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + accelerator.print("enable full bf16 training.") + network.to(weight_dtype) + + unet.requires_grad_(False) + unet.to(dtype=weight_dtype) + for t_enc in text_encoders: + t_enc.requires_grad_(False) + + # acceleratorがなんかよろしくやってくれるらしい + # TODO めちゃくちゃ冗長なのでコードを整理する + if train_unet and train_text_encoder: + if len(text_encoders) > 1: + unet, t_enc1, t_enc2, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoders[0], text_encoders[1], network, optimizer, train_dataloader, lr_scheduler + ) + text_encoder = text_encoders = [t_enc1, t_enc2] + del t_enc1, t_enc2 + else: + unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler + ) + text_encoders = [text_encoder] + elif train_unet: + unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, network, optimizer, train_dataloader, lr_scheduler + ) + elif train_text_encoder: + if len(text_encoders) > 1: + t_enc1, t_enc2, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + text_encoders[0], text_encoders[1], network, optimizer, train_dataloader, lr_scheduler + ) + text_encoder = text_encoders = [t_enc1, t_enc2] + del t_enc1, t_enc2 + else: + text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + text_encoder, network, optimizer, train_dataloader, lr_scheduler + ) + text_encoders = [text_encoder] + + unet.to(accelerator.device, dtype=weight_dtype) # move to device because unet is not prepared by accelerator else: - for subset in dataset.subsets: - dataset_dirs_info[os.path.basename(subset.metadata_file)] = { - "n_repeats": subset.num_repeats, - "img_count": subset.img_count, - } + network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + network, optimizer, train_dataloader, lr_scheduler + ) - metadata.update( - { - "ss_batch_size_per_device": args.train_batch_size, - "ss_total_batch_size": total_batch_size, - "ss_resolution": args.resolution, - "ss_color_aug": bool(args.color_aug), - "ss_flip_aug": 
bool(args.flip_aug), - "ss_random_crop": bool(args.random_crop), - "ss_shuffle_caption": bool(args.shuffle_caption), - "ss_enable_bucket": bool(dataset.enable_bucket), - "ss_bucket_no_upscale": bool(dataset.bucket_no_upscale), - "ss_min_bucket_reso": dataset.min_bucket_reso, - "ss_max_bucket_reso": dataset.max_bucket_reso, - "ss_keep_tokens": args.keep_tokens, - "ss_dataset_dirs": json.dumps(dataset_dirs_info), - "ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info), - "ss_tag_frequency": json.dumps(dataset.tag_frequency), - "ss_bucket_info": json.dumps(dataset.bucket_info), - } - ) + # transform DDP after prepare (train_network here only) + text_encoders = train_util.transform_models_if_DDP(text_encoders) + unet, network = train_util.transform_models_if_DDP([unet, network]) - # add extra args - if args.network_args: - metadata["ss_network_args"] = json.dumps(net_kwargs) - - # model name and hash - if args.pretrained_model_name_or_path is not None: - sd_model_name = args.pretrained_model_name_or_path - if os.path.exists(sd_model_name): - metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name) - metadata["ss_new_sd_model_hash"] = train_util.calculate_sha256(sd_model_name) - sd_model_name = os.path.basename(sd_model_name) - metadata["ss_sd_model_name"] = sd_model_name - - if args.vae is not None: - vae_name = args.vae - if os.path.exists(vae_name): - metadata["ss_vae_hash"] = train_util.model_hash(vae_name) - metadata["ss_new_vae_hash"] = train_util.calculate_sha256(vae_name) - vae_name = os.path.basename(vae_name) - metadata["ss_vae_name"] = vae_name - - metadata = {k: str(v) for k, v in metadata.items()} - - # make minimum metadata for filtering - minimum_keys = ["ss_network_module", "ss_network_dim", "ss_network_alpha", "ss_network_args"] - minimum_metadata = {} - for key in minimum_keys: - if key in metadata: - minimum_metadata[key] = metadata[key] - - progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") - global_step = 0 - - noise_scheduler = DDPMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False - ) - prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) + if args.gradient_checkpointing: + # according to TI example in Diffusers, train is required + unet.train() + for t_enc in text_encoders: + t_enc.train() - if accelerator.is_main_process: - accelerator.init_trackers("network_train" if args.log_tracker_name is None else args.log_tracker_name) + # set top parameter requires_grad = True for gradient checkpointing works + if train_text_encoder: + t_enc.text_model.embeddings.requires_grad_(True) - loss_list = [] - loss_total = 0.0 - del train_dataset_group + # set top parameter requires_grad = True for gradient checkpointing works + if not train_text_encoder: # train U-Net only + unet.parameters().__next__().requires_grad_(True) + else: + unet.eval() + for t_enc in text_encoders: + t_enc.eval() + + del t_enc + + network.prepare_grad_etc(text_encoder, unet) + + if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=vae_dtype) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / 
args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + # TODO: find a way to handle total batch size when there are multiple datasets + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + accelerator.print("running training / 学習開始") + accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # accelerator.print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + # TODO refactor metadata creation and move to util + metadata = { + "ss_session_id": session_id, # random integer indicating which group of epochs the model came from + "ss_training_started_at": training_started_at, # unix timestamp + "ss_output_name": args.output_name, + "ss_learning_rate": args.learning_rate, + "ss_text_encoder_lr": args.text_encoder_lr, + "ss_unet_lr": args.unet_lr, + "ss_num_train_images": train_dataset_group.num_train_images, + "ss_num_reg_images": train_dataset_group.num_reg_images, + "ss_num_batches_per_epoch": len(train_dataloader), + "ss_num_epochs": num_train_epochs, + "ss_gradient_checkpointing": args.gradient_checkpointing, + "ss_gradient_accumulation_steps": args.gradient_accumulation_steps, + "ss_max_train_steps": args.max_train_steps, + "ss_lr_warmup_steps": args.lr_warmup_steps, + "ss_lr_scheduler": args.lr_scheduler, + "ss_network_module": args.network_module, + "ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim + "ss_network_alpha": args.network_alpha, # some networks may not have alpha + "ss_network_dropout": args.network_dropout, # some networks may not have dropout + "ss_mixed_precision": args.mixed_precision, + "ss_full_fp16": bool(args.full_fp16), + "ss_v2": bool(args.v2), + "ss_base_model_version": model_version, + "ss_clip_skip": args.clip_skip, + "ss_max_token_length": args.max_token_length, + "ss_cache_latents": bool(args.cache_latents), + "ss_seed": args.seed, + "ss_lowram": args.lowram, + "ss_noise_offset": args.noise_offset, + "ss_multires_noise_iterations": args.multires_noise_iterations, + "ss_multires_noise_discount": args.multires_noise_discount, + "ss_adaptive_noise_scale": args.adaptive_noise_scale, + "ss_zero_terminal_snr": args.zero_terminal_snr, + "ss_training_comment": args.training_comment, # will not be updated after training + "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(), + "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""), + "ss_max_grad_norm": args.max_grad_norm, + "ss_caption_dropout_rate": args.caption_dropout_rate, + "ss_caption_dropout_every_n_epochs": 
args.caption_dropout_every_n_epochs, + "ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate, + "ss_face_crop_aug_range": args.face_crop_aug_range, + "ss_prior_loss_weight": args.prior_loss_weight, + "ss_min_snr_gamma": args.min_snr_gamma, + "ss_scale_weight_norms": args.scale_weight_norms, + "ss_ip_noise_gamma": args.ip_noise_gamma, + } - # callback for step start - if hasattr(network, "on_step_start"): - on_step_start = network.on_step_start - else: - on_step_start = lambda *args, **kwargs: None + if use_user_config: + # save metadata of multiple datasets + # NOTE: pack "ss_datasets" value as json one time + # or should also pack nested collections as json? + datasets_metadata = [] + tag_frequency = {} # merge tag frequency for metadata editor + dataset_dirs_info = {} # merge subset dirs for metadata editor + + for dataset in train_dataset_group.datasets: + is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset) + dataset_metadata = { + "is_dreambooth": is_dreambooth_dataset, + "batch_size_per_device": dataset.batch_size, + "num_train_images": dataset.num_train_images, # includes repeating + "num_reg_images": dataset.num_reg_images, + "resolution": (dataset.width, dataset.height), + "enable_bucket": bool(dataset.enable_bucket), + "min_bucket_reso": dataset.min_bucket_reso, + "max_bucket_reso": dataset.max_bucket_reso, + "tag_frequency": dataset.tag_frequency, + "bucket_info": dataset.bucket_info, + } - # function for saving/removing - def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False): - os.makedirs(args.output_dir, exist_ok=True) - ckpt_file = os.path.join(args.output_dir, ckpt_name) + subsets_metadata = [] + for subset in dataset.subsets: + subset_metadata = { + "img_count": subset.img_count, + "num_repeats": subset.num_repeats, + "color_aug": bool(subset.color_aug), + "flip_aug": bool(subset.flip_aug), + "random_crop": bool(subset.random_crop), + "shuffle_caption": bool(subset.shuffle_caption), + "keep_tokens": subset.keep_tokens, + } + + image_dir_or_metadata_file = None + if subset.image_dir: + image_dir = os.path.basename(subset.image_dir) + subset_metadata["image_dir"] = image_dir + image_dir_or_metadata_file = image_dir + + if is_dreambooth_dataset: + subset_metadata["class_tokens"] = subset.class_tokens + subset_metadata["is_reg"] = subset.is_reg + if subset.is_reg: + image_dir_or_metadata_file = None # not merging reg dataset + else: + metadata_file = os.path.basename(subset.metadata_file) + subset_metadata["metadata_file"] = metadata_file + image_dir_or_metadata_file = metadata_file # may overwrite + + subsets_metadata.append(subset_metadata) + + # merge dataset dir: not reg subset only + # TODO update additional-network extension to show detailed dataset config from metadata + if image_dir_or_metadata_file is not None: + # datasets may have a certain dir multiple times + v = image_dir_or_metadata_file + i = 2 + while v in dataset_dirs_info: + v = image_dir_or_metadata_file + f" ({i})" + i += 1 + image_dir_or_metadata_file = v + + dataset_dirs_info[image_dir_or_metadata_file] = { + "n_repeats": subset.num_repeats, + "img_count": subset.img_count, + } - print(f"\nsaving checkpoint: {ckpt_file}") - metadata["ss_training_finished_at"] = str(time.time()) - metadata["ss_steps"] = str(steps) - metadata["ss_epoch"] = str(epoch_no) + dataset_metadata["subsets"] = subsets_metadata + datasets_metadata.append(dataset_metadata) + + # merge tag frequency: + for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items(): + # 
あるディレクトリが複数のdatasetで使用されている場合、一度だけ数える + # もともと繰り返し回数を指定しているので、キャプション内でのタグの出現回数と、それが学習で何度使われるかは一致しない + # なので、ここで複数datasetの回数を合算してもあまり意味はない + if ds_dir_name in tag_frequency: + continue + tag_frequency[ds_dir_name] = ds_freq_for_dir + + metadata["ss_datasets"] = json.dumps(datasets_metadata) + metadata["ss_tag_frequency"] = json.dumps(tag_frequency) + metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info) + else: + # conserving backward compatibility when using train_dataset_dir and reg_dataset_dir + assert ( + len(train_dataset_group.datasets) == 1 + ), f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。" - unwrapped_nw.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata) - if args.huggingface_repo_id is not None: - huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) + dataset = train_dataset_group.datasets[0] - def remove_model(old_ckpt_name): - old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) - if os.path.exists(old_ckpt_file): - print(f"removing old checkpoint: {old_ckpt_file}") - os.remove(old_ckpt_file) + dataset_dirs_info = {} + reg_dataset_dirs_info = {} + if use_dreambooth_method: + for subset in dataset.subsets: + info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info + info[os.path.basename(subset.image_dir)] = {"n_repeats": subset.num_repeats, "img_count": subset.img_count} + else: + for subset in dataset.subsets: + dataset_dirs_info[os.path.basename(subset.metadata_file)] = { + "n_repeats": subset.num_repeats, + "img_count": subset.img_count, + } + + metadata.update( + { + "ss_batch_size_per_device": args.train_batch_size, + "ss_total_batch_size": total_batch_size, + "ss_resolution": args.resolution, + "ss_color_aug": bool(args.color_aug), + "ss_flip_aug": bool(args.flip_aug), + "ss_random_crop": bool(args.random_crop), + "ss_shuffle_caption": bool(args.shuffle_caption), + "ss_enable_bucket": bool(dataset.enable_bucket), + "ss_bucket_no_upscale": bool(dataset.bucket_no_upscale), + "ss_min_bucket_reso": dataset.min_bucket_reso, + "ss_max_bucket_reso": dataset.max_bucket_reso, + "ss_keep_tokens": args.keep_tokens, + "ss_dataset_dirs": json.dumps(dataset_dirs_info), + "ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info), + "ss_tag_frequency": json.dumps(dataset.tag_frequency), + "ss_bucket_info": json.dumps(dataset.bucket_info), + } + ) - # training loop - for epoch in range(num_train_epochs): - if is_main_process: - print(f"\nepoch {epoch+1}/{num_train_epochs}") - current_epoch.value = epoch + 1 + # add extra args + if args.network_args: + metadata["ss_network_args"] = json.dumps(net_kwargs) + + # model name and hash + if args.pretrained_model_name_or_path is not None: + sd_model_name = args.pretrained_model_name_or_path + if os.path.exists(sd_model_name): + metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name) + metadata["ss_new_sd_model_hash"] = train_util.calculate_sha256(sd_model_name) + sd_model_name = os.path.basename(sd_model_name) + metadata["ss_sd_model_name"] = sd_model_name + + if args.vae is not None: + vae_name = args.vae + if os.path.exists(vae_name): + metadata["ss_vae_hash"] = train_util.model_hash(vae_name) + metadata["ss_new_vae_hash"] = train_util.calculate_sha256(vae_name) + vae_name = os.path.basename(vae_name) + metadata["ss_vae_name"] = vae_name + + metadata = {k: str(v) for k, v in 
metadata.items()} + + # make minimum metadata for filtering + minimum_metadata = {} + for key in train_util.SS_METADATA_MINIMUM_KEYS: + if key in metadata: + minimum_metadata[key] = metadata[key] + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + ) + prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) + if args.zero_terminal_snr: + custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) + + if accelerator.is_main_process: + init_kwargs = {} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + "network_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs + ) - metadata["ss_epoch"] = str(epoch + 1) + loss_list = [] + loss_total = 0.0 + del train_dataset_group - network.on_epoch_start(text_encoder, unet) + # callback for step start + if hasattr(network, "on_step_start"): + on_step_start = network.on_step_start + else: + on_step_start = lambda *args, **kwargs: None + + # function for saving/removing + def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False): + os.makedirs(args.output_dir, exist_ok=True) + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + accelerator.print(f"\nsaving checkpoint: {ckpt_file}") + metadata["ss_training_finished_at"] = str(time.time()) + metadata["ss_steps"] = str(steps) + metadata["ss_epoch"] = str(epoch_no) + + metadata_to_save = minimum_metadata if args.no_metadata else metadata + sai_metadata = train_util.get_sai_model_spec(None, args, self.is_sdxl, True, False) + metadata_to_save.update(sai_metadata) + + unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save) + if args.huggingface_repo_id is not None: + huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) + + def remove_model(old_ckpt_name): + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + accelerator.print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) + + # training loop + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + metadata["ss_epoch"] = str(epoch + 1) + + network.on_epoch_start(text_encoder, unet) + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + with accelerator.accumulate(network): + on_step_start(text_encoder, unet) + + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device) + else: + # latentに変換 + latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample() + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents) + latents = latents * self.vae_scale_factor + b_size = latents.shape[0] + + with torch.set_grad_enabled(train_text_encoder): + # Get the text embedding for conditioning + if args.weighted_captions: + text_encoder_conds = get_weighted_text_embeddings( + tokenizer, + text_encoder, + batch["captions"], + accelerator.device, + args.max_token_length // 75 if 
args.max_token_length else 1, + clip_skip=args.clip_skip, + ) + else: + text_encoder_conds = self.get_text_cond( + args, accelerator, batch, tokenizers, text_encoders, weight_dtype + ) - for step, batch in enumerate(train_dataloader): - current_step.value = global_step - with accelerator.accumulate(network): - on_step_start(text_encoder, unet) + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents + ) - with torch.no_grad(): - if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) - else: - # latentに変換 - latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() - latents = latents * 0.18215 - b_size = latents.shape[0] - - with torch.set_grad_enabled(train_text_encoder): - # Get the text embedding for conditioning - if args.weighted_captions: - encoder_hidden_states = get_weighted_text_embeddings( - tokenizer, - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, + # Predict the noise residual + with accelerator.autocast(): + noise_pred = self.call_unet( + args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype ) - else: - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype) - - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents, device=latents.device) - if args.noise_offset: - noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale) - elif args.multires_noise_iterations: - noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount) - - # Sample a random timestep for each image - timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) - timesteps = timesteps.long() - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - # Predict the noise residual - with accelerator.autocast(): - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - loss = loss.mean([1, 2, 3]) + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise - loss_weights = batch["loss_weights"] # 各sampleごとのweight - loss = loss * loss_weights + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) - if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) - if args.scale_v_pred_loss_like_noise_pred: - loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + if args.min_snr_gamma: 
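
The inline noise-offset / multires-noise / timestep-sampling block is consolidated into `train_util.get_noise_noisy_latents_and_timesteps`. A simplified sketch of what that helper covers, based on the inline code it replaces (plain noise offset only; the real helper also handles adaptive noise scale and pyramid/multires noise):

```python
import torch
from diffusers import DDPMScheduler


def get_noise_noisy_latents_and_timesteps_sketch(
    noise_scheduler: DDPMScheduler, latents: torch.Tensor, noise_offset: float | None = None
):
    """Simplified stand-in for train_util.get_noise_noisy_latents_and_timesteps."""
    noise = torch.randn_like(latents, device=latents.device)
    if noise_offset:
        # shift the per-sample, per-channel mean of the noise (simplified noise offset)
        noise = noise + noise_offset * torch.randn(
            (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
        )
    b_size = latents.shape[0]
    # sample a random timestep for each image
    timesteps = torch.randint(
        0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device
    ).long()
    # forward diffusion: add noise according to the magnitude at each sampled timestep
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
    return noise, noisy_latents, timesteps
```

This replaces the duplicated sampling blocks that previously appeared in each training script.
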
+ loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.v_pred_like_loss: + loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) - accelerator.backward(loss) - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = network.get_trainable_params() - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) + accelerator.backward(loss) + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = network.get_trainable_params() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - if args.scale_weight_norms: - keys_scaled, mean_norm, maximum_norm = network.apply_max_norm_regularization( - args.scale_weight_norms, accelerator.device - ) - max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm} - else: - keys_scaled, mean_norm, maximum_norm = None, None, None + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 + if args.scale_weight_norms: + keys_scaled, mean_norm, maximum_norm = network.apply_max_norm_regularization( + args.scale_weight_norms, accelerator.device + ) + max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm} + else: + keys_scaled, mean_norm, maximum_norm = None, None, None - train_util.sample_images( - accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet - ) + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 - # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: - accelerator.wait_for_everyone() - if accelerator.is_main_process: - ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) - save_model(ckpt_name, unwrap_model(network), global_step, epoch) + self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) - if args.save_state: - train_util.save_and_remove_state_stepwise(args, accelerator, global_step) + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) + save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch) - remove_step_no = train_util.get_remove_step_no(args, global_step) - if remove_step_no is not None: - remove_ckpt_name = train_util.get_step_ckpt_name(args, "." 
+ args.save_model_as, remove_step_no) - remove_model(remove_ckpt_name) + if args.save_state: + train_util.save_and_remove_state_stepwise(args, accelerator, global_step) - current_loss = loss.detach().item() - if epoch == 0: - loss_list.append(current_loss) - else: - loss_total -= loss_list[step] - loss_list[step] = current_loss - loss_total += current_loss - avr_loss = loss_total / len(loss_list) - logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) + remove_step_no = train_util.get_remove_step_no(args, global_step) + if remove_step_no is not None: + remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no) + remove_model(remove_ckpt_name) - if args.scale_weight_norms: - progress_bar.set_postfix(**{**max_mean_logs, **logs}) + current_loss = loss.detach().item() + if epoch == 0: + loss_list.append(current_loss) + else: + loss_total -= loss_list[step] + loss_list[step] = current_loss + loss_total += current_loss + avr_loss = loss_total / len(loss_list) + logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) - if args.logging_dir is not None: - logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) - accelerator.log(logs, step=global_step) + if args.scale_weight_norms: + progress_bar.set_postfix(**{**max_mean_logs, **logs}) - if global_step >= args.max_train_steps: - break + if args.logging_dir is not None: + logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) + accelerator.log(logs, step=global_step) - if args.logging_dir is not None: - logs = {"loss/epoch": loss_total / len(loss_list)} - accelerator.log(logs, step=epoch + 1) + if global_step >= args.max_train_steps: + break - accelerator.wait_for_everyone() + if args.logging_dir is not None: + logs = {"loss/epoch": loss_total / len(loss_list)} + accelerator.log(logs, step=epoch + 1) - # 指定エポックごとにモデルを保存 - if args.save_every_n_epochs is not None: - saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs - if is_main_process and saving: - ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1) - save_model(ckpt_name, unwrap_model(network), global_step, epoch + 1) + accelerator.wait_for_everyone() - remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) - if remove_epoch_no is not None: - remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no) - remove_model(remove_ckpt_name) + # 指定エポックごとにモデルを保存 + if args.save_every_n_epochs is not None: + saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs + if is_main_process and saving: + ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1) + save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch + 1) - if args.save_state: - train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) + remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) + if remove_epoch_no is not None: + remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." 
+ args.save_model_as, remove_epoch_no) + remove_model(remove_ckpt_name) - train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + if args.save_state: + train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) - # end of epoch + self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) - # metadata["ss_epoch"] = str(num_train_epochs) - metadata["ss_training_finished_at"] = str(time.time()) + # end of epoch - if is_main_process: - network = unwrap_model(network) + # metadata["ss_epoch"] = str(num_train_epochs) + metadata["ss_training_finished_at"] = str(time.time()) - accelerator.end_training() + if is_main_process: + network = accelerator.unwrap_model(network) - if is_main_process and args.save_state: - train_util.save_state_on_train_end(args, accelerator) + accelerator.end_training() - del accelerator # この後メモリを使うのでこれは消す + if is_main_process and args.save_state: + train_util.save_state_on_train_end(args, accelerator) - if is_main_process: - ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) - save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True) + if is_main_process: + ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) + save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True) - print("model saved.") + print("model saved.") def setup_parser() -> argparse.ArgumentParser: @@ -861,6 +988,11 @@ def setup_parser() -> argparse.ArgumentParser: nargs="*", help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率", ) + parser.add_argument( + "--no_half_vae", + action="store_true", + help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", + ) return parser @@ -870,4 +1002,5 @@ def setup_parser() -> argparse.ArgumentParser: args = parser.parse_args() args = train_util.read_config_from_file(args, parser) - train(args) + trainer = NetworkTrainer() + trainer.train(args) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index d08251e12..1c7b7fcb2 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -1,16 +1,23 @@ -import importlib import argparse import gc import math import os -import toml from multiprocessing import Value +import toml from tqdm import tqdm import torch +try: + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + from library.ipex import ipex_init + ipex_init() +except Exception: + pass from accelerate.utils import set_seed -import diffusers from diffusers import DDPMScheduler +from transformers import CLIPTokenizer +from library import model_util import library.train_util as train_util import library.huggingface_util as huggingface_util @@ -23,9 +30,8 @@ from library.custom_train_functions import ( apply_snr_weight, prepare_scheduler_for_custom_training, - pyramid_noise_like, - apply_noise_offset, scale_v_prediction_loss_like_noise_prediction, + add_v_prediction_like_loss, ) imagenet_templates_small = [ @@ -81,504 +87,641 @@ ] -def train(args): - if args.output_name is None: - args.output_name = args.token_string - use_template = args.use_object_template or args.use_style_template - - train_util.verify_training_args(args) - train_util.prepare_dataset_args(args, True) - - cache_latents = args.cache_latents +class TextualInversionTrainer: + def __init__(self): 
+ self.vae_scale_factor = 0.18215 + self.is_sdxl = False - if args.seed is not None: - set_seed(args.seed) + def assert_extra_args(self, args, train_dataset_group): + pass - tokenizer = train_util.load_tokenizer(args) + def load_target_model(self, args, weight_dtype, accelerator): + text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) + return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet - # acceleratorを準備する - print("prepare accelerator") - accelerator, unwrap_model = train_util.prepare_accelerator(args) + def load_tokenizer(self, args): + tokenizer = train_util.load_tokenizer(args) + return tokenizer - # mixed precisionに対応した型を用意しておき適宜castする - weight_dtype, save_dtype = train_util.prepare_dtype(args) + def assert_token_string(self, token_string, tokenizers: CLIPTokenizer): + pass - # モデルを読み込む - text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) + def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): + with torch.enable_grad(): + input_ids = batch["input_ids"].to(accelerator.device) + encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizers[0], text_encoders[0], None) + return encoder_hidden_states - # Convert the init_word to token_id - if args.init_word is not None: - init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False) - if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token: - print( - f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}" - ) - else: - init_token_ids = None - - # add new word to tokenizer, count is num_vectors_per_token - token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)] - num_added_tokens = tokenizer.add_tokens(token_strings) - assert ( - num_added_tokens == args.num_vectors_per_token - ), f"tokenizer has same word to token string. 
please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}" - - token_ids = tokenizer.convert_tokens_to_ids(token_strings) - print(f"tokens are added: {token_ids}") - assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered" - assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}" - - # Resize the token embeddings as we are adding new special tokens to the tokenizer - text_encoder.resize_token_embeddings(len(tokenizer)) - - # Initialise the newly added placeholder token with the embeddings of the initializer token - token_embeds = text_encoder.get_input_embeddings().weight.data - if init_token_ids is not None: - for i, token_id in enumerate(token_ids): - token_embeds[token_id] = token_embeds[init_token_ids[i % len(init_token_ids)]] - # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) - - # load weights - if args.weights is not None: - embeddings = load_weights(args.weights) - assert len(token_ids) == len( - embeddings - ), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}" - # print(token_ids, embeddings.size()) - for token_id, embedding in zip(token_ids, embeddings): - token_embeds[token_id] = embedding - # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) - print(f"weighs loaded") - - print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}") - - # データセットを準備する - if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False)) - if args.dataset_config is not None: - print(f"Load dataset config from {args.dataset_config}") - user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "reg_data_dir", "in_json"] - if any(getattr(args, attr) is not None for attr in ignored): - print( - "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( - ", ".join(ignored) - ) - ) - else: - use_dreambooth_method = args.in_json is None - if use_dreambooth_method: - print("Use DreamBooth method.") - user_config = { - "datasets": [ - {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)} - ] - } - else: - print("Train with captions.") - user_config = { - "datasets": [ - { - "subsets": [ - { - "image_dir": args.train_data_dir, - "metadata_file": args.in_json, - } - ] - } - ] - } - - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) - else: - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) - - current_epoch = Value("i", 0) - current_step = Value("i", 0) - ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None - collater = train_util.collater_class(current_epoch, current_step, ds_for_collater) - - # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装 - if use_template: - print("use template for training captions. 
is object: {args.use_object_template}") - templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small - replace_to = " ".join(token_strings) - captions = [] - for tmpl in templates: - captions.append(tmpl.format(replace_to)) - train_dataset_group.add_replacement("", captions) - - if args.num_vectors_per_token > 1: - prompt_replacement = (args.token_string, replace_to) - else: - prompt_replacement = None - else: - if args.num_vectors_per_token > 1: - replace_to = " ".join(token_strings) - train_dataset_group.add_replacement(args.token_string, replace_to) - prompt_replacement = (args.token_string, replace_to) - else: - prompt_replacement = None - - if args.debug_dataset: - train_util.debug_dataset(train_dataset_group, show_input_ids=True) - return - if len(train_dataset_group) == 0: - print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください") - return - - if cache_latents: - assert ( - train_dataset_group.is_latent_cacheable() - ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - - # モデルに xformers とか memory efficient attention を組み込む - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) - - # 学習を準備する - if cache_latents: - vae.to(accelerator.device, dtype=weight_dtype) - vae.requires_grad_(False) - vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) - vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() - - accelerator.wait_for_everyone() - - if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() - text_encoder.gradient_checkpointing_enable() - - # 学習に必要なクラスを準備する - print("prepare optimizer, data loader etc.") - trainable_params = text_encoder.get_input_embeddings().parameters() - _, _, optimizer = train_util.get_optimizer(args, trainable_params) - - # dataloaderを準備する - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで - train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, - batch_size=1, - shuffle=True, - collate_fn=collater, - num_workers=n_workers, - persistent_workers=args.persistent_data_loader_workers, - ) + def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): + noise_pred = unet(noisy_latents, timesteps, text_conds).sample + return noise_pred - # 学習ステップ数を計算する - if args.max_train_epochs is not None: - args.max_train_steps = args.max_train_epochs * math.ceil( - len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement): + train_util.sample_images( + accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement ) - print(f"override steps. 
steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") - - # データセット側にも学習ステップを送信 - train_dataset_group.set_max_train_steps(args.max_train_steps) - # lr schedulerを用意する - lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + def save_weights(self, file, updated_embs, save_dtype, metadata): + state_dict = {"emb_params": updated_embs[0]} - # acceleratorがなんかよろしくやってくれるらしい - text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - text_encoder, optimizer, train_dataloader, lr_scheduler - ) - - # transform DDP after prepare - text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet) - - index_no_updates = torch.arange(len(tokenizer)) < token_ids[0] - # print(len(index_no_updates), torch.sum(index_no_updates)) - orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone() - - # Freeze all parameters except for the token embeddings in text encoder - text_encoder.requires_grad_(True) - text_encoder.text_model.encoder.requires_grad_(False) - text_encoder.text_model.final_layer_norm.requires_grad_(False) - text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) - # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True) - - unet.requires_grad_(False) - unet.to(accelerator.device, dtype=weight_dtype) - if args.gradient_checkpointing: # according to TI example in Diffusers, train is required - unet.train() - else: - unet.eval() - - if not cache_latents: - vae.requires_grad_(False) - vae.eval() - vae.to(accelerator.device, dtype=weight_dtype) - - # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする - if args.full_fp16: - train_util.patch_accelerator_for_fp16_training(accelerator) - text_encoder.to(weight_dtype) - - # resumeする - train_util.resume_from_local_or_hf_if_specified(accelerator, args) - - # epoch数を計算する - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): - args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 - - # 学習する - total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - print("running training / 学習開始") - print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") - print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") - print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") - print(f" num epochs / epoch数: {num_train_epochs}") - print(f" batch size per device / バッチサイズ: {args.train_batch_size}") - print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") - print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") - print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") - - progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") - global_step = 0 - - noise_scheduler = DDPMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False - ) - prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) - - if accelerator.is_main_process: - accelerator.init_trackers("textual_inversion" if 
args.log_tracker_name is None else args.log_tracker_name) - - # function for saving/removing - def save_model(ckpt_name, embs, steps, epoch_no, force_sync_upload=False): - os.makedirs(args.output_dir, exist_ok=True) - ckpt_file = os.path.join(args.output_dir, ckpt_name) - - print(f"\nsaving checkpoint: {ckpt_file}") - save_weights(ckpt_file, embs, save_dtype) - if args.huggingface_repo_id is not None: - huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) - - def remove_model(old_ckpt_name): - old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) - if os.path.exists(old_ckpt_file): - print(f"removing old checkpoint: {old_ckpt_file}") - os.remove(old_ckpt_file) - - # training loop - for epoch in range(num_train_epochs): - print(f"\nepoch {epoch+1}/{num_train_epochs}") - current_epoch.value = epoch + 1 - - text_encoder.train() - - loss_total = 0 - - for step, batch in enumerate(train_dataloader): - current_step.value = global_step - with accelerator.accumulate(text_encoder): - with torch.no_grad(): - if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) - else: - # latentに変換 - latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() - latents = latents * 0.18215 - b_size = latents.shape[0] - - # Get the text embedding for conditioning - input_ids = batch["input_ids"].to(accelerator.device) - # use float instead of fp16/bf16 because text encoder is float - encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, torch.float) - - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents, device=latents.device) - if args.noise_offset: - noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale) - elif args.multires_noise_iterations: - noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount) - - # Sample a random timestep for each image - timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) - timesteps = timesteps.long() - - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - # Predict the noise residual - with accelerator.autocast(): - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise - - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - loss = loss.mean([1, 2, 3]) - - loss_weights = batch["loss_weights"] # 各sampleごとのweight - loss = loss * loss_weights + if save_dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v - if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) - if args.scale_v_pred_loss_like_noise_pred: - loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) # can be loaded in Web UI - accelerator.backward(loss) - 
if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = text_encoder.get_input_embeddings().parameters() - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) + data = load_file(file) + else: + # compatible to Web UI's file format + data = torch.load(file, map_location="cpu") + if type(data) != dict: + raise ValueError(f"weight file is not dict / 重みファイルがdict形式ではありません: {file}") - # Let's make sure we don't update any embedding weights besides the newly added token - with torch.no_grad(): - unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[ - index_no_updates - ] + if "string_to_param" in data: # textual inversion embeddings + data = data["string_to_param"] + if hasattr(data, "_parameters"): # support old PyTorch? + data = getattr(data, "_parameters") - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 + emb = next(iter(data.values())) + if type(emb) != torch.Tensor: + raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {file}") - train_util.sample_images( - accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement - ) + if len(emb.size()) == 1: + emb = emb.unsqueeze(0) - # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: - accelerator.wait_for_everyone() - if accelerator.is_main_process: - updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone() + return [emb] - ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) - save_model(ckpt_name, updated_embs, global_step, epoch) + def train(self, args): + if args.output_name is None: + args.output_name = args.token_string + use_template = args.use_object_template or args.use_style_template - if args.save_state: - train_util.save_and_remove_state_stepwise(args, accelerator, global_step) + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) - remove_step_no = train_util.get_remove_step_no(args, global_step) - if remove_step_no is not None: - remove_ckpt_name = train_util.get_step_ckpt_name(args, "." 
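
The `save_weights` / `load_weights` methods above define the embedding file format: a single `emb_params` tensor, optionally cast to `save_dtype`, written via safetensors with metadata (or `torch.save` for Web UI compatibility), and read back with support for the Web UI's `string_to_param` layout. A toy round trip using the same `emb_params` key; the file name, vector dimensions and metadata values here are made up:

```python
import torch
from safetensors.torch import save_file, load_file

# toy embedding: 2 vectors of dimension 768 (SD1.x text encoder width), standing in for the trained rows
emb = torch.randn(2, 768)

# save in the same layout as save_weights: a single "emb_params" tensor, cast to the save dtype
state_dict = {"emb_params": emb.to(torch.float16)}
# metadata must be a str->str dict; the trainers store their own ss_* / model-spec keys here
save_file(state_dict, "my_embedding.safetensors", metadata={"comment": "toy example"})

# load it back the way load_weights does for .safetensors files
data = load_file("my_embedding.safetensors")
loaded = data["emb_params"]
assert loaded.shape == (2, 768)
```
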
+ args.save_model_as, remove_step_no) - remove_model(remove_ckpt_name) + cache_latents = args.cache_latents - current_loss = loss.detach().item() - if args.logging_dir is not None: - logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} - if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): # tracking d*lr value - logs["lr/d*lr"] = ( - lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] - ) - accelerator.log(logs, step=global_step) + if args.seed is not None: + set_seed(args.seed) - loss_total += current_loss - avr_loss = loss_total / (step + 1) - logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) + tokenizer_or_list = self.load_tokenizer(args) # list of tokenizer or tokenizer + tokenizers = tokenizer_or_list if isinstance(tokenizer_or_list, list) else [tokenizer_or_list] - if global_step >= args.max_train_steps: - break + # acceleratorを準備する + print("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) - if args.logging_dir is not None: - logs = {"loss/epoch": loss_total / len(train_dataloader)} - accelerator.log(logs, step=epoch + 1) + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + vae_dtype = torch.float32 if args.no_half_vae else weight_dtype - accelerator.wait_for_everyone() + # モデルを読み込む + model_version, text_encoder_or_list, vae, unet = self.load_target_model(args, weight_dtype, accelerator) + text_encoders = [text_encoder_or_list] if not isinstance(text_encoder_or_list, list) else text_encoder_or_list - updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone() + if len(text_encoders) > 1 and args.gradient_accumulation_steps > 1: + accelerator.print( + "accelerate doesn't seem to support gradient_accumulation_steps for multiple models (text encoders) / " + + "accelerateでは複数のモデル(テキストエンコーダー)のgradient_accumulation_stepsはサポートされていないようです" + ) - if args.save_every_n_epochs is not None: - saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs - if accelerator.is_main_process and saving: - ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1) - save_model(ckpt_name, updated_embs, epoch + 1, global_step) + # Convert the init_word to token_id + init_token_ids_list = [] + if args.init_word is not None: + for i, tokenizer in enumerate(tokenizers): + init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False) + if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token: + accelerator.print( + f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / " + + f"初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: tokenizer {i+1}, length {len(init_token_ids)}" + ) + init_token_ids_list.append(init_token_ids) + else: + init_token_ids_list = [None] * len(tokenizers) + + # tokenizerに新しい単語を追加する。追加する単語の数はnum_vectors_per_token + # token_stringが hoge の場合、"hoge", "hoge1", "hoge2", ... が追加される + # add new word to tokenizer, count is num_vectors_per_token + # if token_string is hoge, "hoge", "hoge1", "hoge2", ... 
are added + + self.assert_token_string(args.token_string, tokenizers) + + token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)] + token_ids_list = [] + token_embeds_list = [] + for i, (tokenizer, text_encoder, init_token_ids) in enumerate(zip(tokenizers, text_encoders, init_token_ids_list)): + num_added_tokens = tokenizer.add_tokens(token_strings) + assert ( + num_added_tokens == args.num_vectors_per_token + ), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: tokenizer {i+1}, {args.token_string}" + + token_ids = tokenizer.convert_tokens_to_ids(token_strings) + accelerator.print(f"tokens are added for tokenizer {i+1}: {token_ids}") + assert ( + min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1 + ), f"token ids is not ordered : tokenizer {i+1}, {token_ids}" + assert ( + len(tokenizer) - 1 == token_ids[-1] + ), f"token ids is not end of tokenize: tokenizer {i+1}, {token_ids}, {len(tokenizer)}" + token_ids_list.append(token_ids) + + # Resize the token embeddings as we are adding new special tokens to the tokenizer + text_encoder.resize_token_embeddings(len(tokenizer)) + + # Initialise the newly added placeholder token with the embeddings of the initializer token + token_embeds = text_encoder.get_input_embeddings().weight.data + if init_token_ids is not None: + for i, token_id in enumerate(token_ids): + token_embeds[token_id] = token_embeds[init_token_ids[i % len(init_token_ids)]] + # accelerator.print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) + token_embeds_list.append(token_embeds) + + # load weights + if args.weights is not None: + embeddings_list = self.load_weights(args.weights) + assert len(token_ids) == len( + embeddings_list[0] + ), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}" + # accelerator.print(token_ids, embeddings.size()) + for token_ids, embeddings, token_embeds in zip(token_ids_list, embeddings_list, token_embeds_list): + for token_id, embedding in zip(token_ids, embeddings): + token_embeds[token_id] = embedding + # accelerator.print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) + accelerator.print(f"weighs loaded") + + accelerator.print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}") + + # データセットを準備する + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, False)) + if args.dataset_config is not None: + accelerator.print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "reg_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + accelerator.print( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + use_dreambooth_method = args.in_json is None + if use_dreambooth_method: + accelerator.print("Use DreamBooth method.") + user_config = { + "datasets": [ + { + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( + args.train_data_dir, args.reg_data_dir + ) + } + ] + } + else: + print("Train with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = 
blueprint_generator.generate(user_config, args, tokenizer=tokenizer_or_list) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer_or_list) - remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) - if remove_epoch_no is not None: - remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no) - remove_model(remove_ckpt_name) + self.assert_extra_args(args, train_dataset_group) - if args.save_state: - train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collater = train_util.collater_class(current_epoch, current_step, ds_for_collater) - train_util.sample_images( - accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement + # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装 + if use_template: + accelerator.print(f"use template for training captions. is object: {args.use_object_template}") + templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small + replace_to = " ".join(token_strings) + captions = [] + for tmpl in templates: + captions.append(tmpl.format(replace_to)) + train_dataset_group.add_replacement("", captions) + + # サンプル生成用 + if args.num_vectors_per_token > 1: + prompt_replacement = (args.token_string, replace_to) + else: + prompt_replacement = None + else: + # サンプル生成用 + if args.num_vectors_per_token > 1: + replace_to = " ".join(token_strings) + train_dataset_group.add_replacement(args.token_string, replace_to) + prompt_replacement = (args.token_string, replace_to) + else: + prompt_replacement = None + + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group, show_input_ids=True) + return + if len(train_dataset_group) == 0: + accelerator.print("No data found. 
Please verify arguments / 画像がありません。引数指定を確認してください") + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + # モデルに xformers とか memory efficient attention を組み込む + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) + if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える + vae.set_use_memory_efficient_attention_xformers(args.xformers) + + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=vae_dtype) + vae.requires_grad_(False) + vae.eval() + with torch.no_grad(): + train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + vae.to("cpu") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + accelerator.wait_for_everyone() + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + for text_encoder in text_encoders: + text_encoder.gradient_checkpointing_enable() + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + trainable_params = [] + for text_encoder in text_encoders: + trainable_params += text_encoder.get_input_embeddings().parameters() + _, _, optimizer = train_util.get_optimizer(args, trainable_params) + + # dataloaderを準備する + # DataLoaderのプロセス数:0はメインプロセスになる + n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collater, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, ) - # end of epoch + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) - is_main_process = accelerator.is_main_process - if is_main_process: - text_encoder = unwrap_model(text_encoder) + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) - accelerator.end_training() + # lr schedulerを用意する + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) - if args.save_state and is_main_process: - train_util.save_state_on_train_end(args, accelerator) + # acceleratorがなんかよろしくやってくれるらしい + if len(text_encoders) == 1: + text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + text_encoder_or_list, optimizer, train_dataloader, lr_scheduler + ) + # transform DDP after prepare + text_encoder_or_list, unet = train_util.transform_if_model_is_DDP(text_encoder_or_list, unet) - updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone() + elif len(text_encoders) == 2: + text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler + ) + # transform DDP after prepare + text_encoder1, text_encoder2, unet = train_util.transform_if_model_is_DDP(text_encoder1, text_encoder2, unet) - del accelerator # この後メモリを使うのでこれは消す + text_encoder_or_list = text_encoders = [text_encoder1, text_encoder2] - if is_main_process: - ckpt_name = train_util.get_last_ckpt_name(args, "." 
+ args.save_model_as) - save_model(ckpt_name, updated_embs, global_step, num_train_epochs, force_sync_upload=True) + else: + raise NotImplementedError() + + index_no_updates_list = [] + orig_embeds_params_list = [] + for tokenizer, token_ids, text_encoder in zip(tokenizers, token_ids_list, text_encoders): + index_no_updates = torch.arange(len(tokenizer)) < token_ids[0] + index_no_updates_list.append(index_no_updates) + + # accelerator.print(len(index_no_updates), torch.sum(index_no_updates)) + orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone() + orig_embeds_params_list.append(orig_embeds_params) + + # Freeze all parameters except for the token embeddings in text encoder + text_encoder.requires_grad_(True) + text_encoder.text_model.encoder.requires_grad_(False) + text_encoder.text_model.final_layer_norm.requires_grad_(False) + text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) + # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True) + + unet.requires_grad_(False) + unet.to(accelerator.device, dtype=weight_dtype) + if args.gradient_checkpointing: # according to TI example in Diffusers, train is required + # TODO U-Netをオリジナルに置き換えたのでいらないはずなので、後で確認して消す + unet.train() + else: + unet.eval() + + if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=vae_dtype) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + for text_encoder in text_encoders: + text_encoder.to(weight_dtype) + if args.full_bf16: + for text_encoder in text_encoders: + text_encoder.to(weight_dtype) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + accelerator.print("running training / 学習開始") + accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print(f" batch size per device / バッチサイズ: {args.train_batch_size}") + accelerator.print( + f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" + ) + accelerator.print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") - print("model saved.") + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + ) + prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) + if 
args.zero_terminal_snr: + custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) + + if accelerator.is_main_process: + init_kwargs = {} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs + ) -def save_weights(file, updated_embs, save_dtype): - state_dict = {"emb_params": updated_embs} + # function for saving/removing + def save_model(ckpt_name, embs_list, steps, epoch_no, force_sync_upload=False): + os.makedirs(args.output_dir, exist_ok=True) + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + accelerator.print(f"\nsaving checkpoint: {ckpt_file}") + + sai_metadata = train_util.get_sai_model_spec(None, args, self.is_sdxl, False, True) + + self.save_weights(ckpt_file, embs_list, save_dtype, sai_metadata) + if args.huggingface_repo_id is not None: + huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) + + def remove_model(old_ckpt_name): + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + accelerator.print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) + + # training loop + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + for text_encoder in text_encoders: + text_encoder.train() + + loss_total = 0 + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + with accelerator.accumulate(text_encoders[0]): + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device) + else: + # latentに変換 + latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample() + latents = latents * self.vae_scale_factor + + # Get the text embedding for conditioning + text_encoder_conds = self.get_text_cond(args, accelerator, batch, tokenizers, text_encoders, weight_dtype) + + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents + ) - if save_dtype is not None: - for key in list(state_dict.keys()): - v = state_dict[key] - v = v.detach().clone().to("cpu").to(save_dtype) - state_dict[key] = v + # Predict the noise residual + with accelerator.autocast(): + noise_pred = self.call_unet( + args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype + ) - if os.path.splitext(file)[1] == ".safetensors": - from safetensors.torch import save_file + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.v_pred_like_loss: + loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + + 
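
After the per-sample `loss_weights`, the loss can be further reweighted by `--min_snr_gamma`, `--scale_v_pred_loss_like_noise_pred` and `--v_pred_like_loss`. The sketch below shows only the min-SNR-gamma idea for the ε-prediction case (clip the per-timestep SNR at γ, then divide by SNR); the repository's `apply_snr_weight` in `library/custom_train_functions.py` is the authoritative implementation:

```python
import torch
from diffusers import DDPMScheduler


def min_snr_weight_sketch(noise_scheduler: DDPMScheduler, timesteps: torch.Tensor, gamma: float) -> torch.Tensor:
    """Per-sample weights for min-SNR-gamma (epsilon prediction): weight = min(SNR(t), gamma) / SNR(t)."""
    alphas_cumprod = noise_scheduler.alphas_cumprod.to(timesteps.device)[timesteps]
    snr = alphas_cumprod / (1.0 - alphas_cumprod)  # SNR(t) = alpha_bar_t / (1 - alpha_bar_t)
    return torch.minimum(snr, torch.full_like(snr, gamma)) / snr


# usage against a per-sample loss vector, mirroring `loss = apply_snr_weight(...)` in the loop
scheduler = DDPMScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
)
timesteps = torch.randint(0, 1000, (4,))
per_sample_loss = torch.rand(4)
weighted = per_sample_loss * min_snr_weight_sketch(scheduler, timesteps, gamma=5.0)
```
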
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + accelerator.backward(loss) + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = text_encoder.get_input_embeddings().parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Let's make sure we don't update any embedding weights besides the newly added token + with torch.no_grad(): + for text_encoder, orig_embeds_params, index_no_updates in zip( + text_encoders, orig_embeds_params_list, index_no_updates_list + ): + accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[ + index_no_updates + ] = orig_embeds_params[index_no_updates] + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + self.sample_images( + accelerator, + args, + None, + global_step, + accelerator.device, + vae, + tokenizer_or_list, + text_encoder_or_list, + unet, + prompt_replacement, + ) - save_file(state_dict, file) - else: - torch.save(state_dict, file) # can be loaded in Web UI + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + updated_embs_list = [] + for text_encoder, token_ids in zip(text_encoders, token_ids_list): + updated_embs = ( + accelerator.unwrap_model(text_encoder) + .get_input_embeddings() + .weight[token_ids] + .data.detach() + .clone() + ) + updated_embs_list.append(updated_embs) + + ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) + save_model(ckpt_name, updated_embs_list, global_step, epoch) + + if args.save_state: + train_util.save_and_remove_state_stepwise(args, accelerator, global_step) + + remove_step_no = train_util.get_remove_step_no(args, global_step) + if remove_step_no is not None: + remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no) + remove_model(remove_ckpt_name) + + current_loss = loss.detach().item() + if args.logging_dir is not None: + logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} + if ( + args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower() + ): # tracking d*lr value + logs["lr/d*lr"] = ( + lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] + ) + accelerator.log(logs, step=global_step) + + loss_total += current_loss + avr_loss = loss_total / (step + 1) + logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + if args.logging_dir is not None: + logs = {"loss/epoch": loss_total / len(train_dataloader)} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + updated_embs_list = [] + for text_encoder, token_ids in zip(text_encoders, token_ids_list): + updated_embs = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone() + updated_embs_list.append(updated_embs) + + if args.save_every_n_epochs is not None: + saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs + if accelerator.is_main_process and saving: + ckpt_name = train_util.get_epoch_ckpt_name(args, "." 
+ args.save_model_as, epoch + 1) + save_model(ckpt_name, updated_embs_list, epoch + 1, global_step) + + remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) + if remove_epoch_no is not None: + remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no) + remove_model(remove_ckpt_name) + + if args.save_state: + train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) + + self.sample_images( + accelerator, + args, + epoch + 1, + global_step, + accelerator.device, + vae, + tokenizer_or_list, + text_encoder_or_list, + unet, + prompt_replacement, + ) + + # end of epoch -def load_weights(file): - if os.path.splitext(file)[1] == ".safetensors": - from safetensors.torch import load_file + is_main_process = accelerator.is_main_process + if is_main_process: + text_encoder = accelerator.unwrap_model(text_encoder) - data = load_file(file) - else: - # compatible to Web UI's file format - data = torch.load(file, map_location="cpu") - if type(data) != dict: - raise ValueError(f"weight file is not dict / 重みファイルがdict形式ではありません: {file}") + accelerator.end_training() - if "string_to_param" in data: # textual inversion embeddings - data = data["string_to_param"] - if hasattr(data, "_parameters"): # support old PyTorch? - data = getattr(data, "_parameters") + if args.save_state and is_main_process: + train_util.save_state_on_train_end(args, accelerator) - emb = next(iter(data.values())) - if type(emb) != torch.Tensor: - raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {file}") + updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone() - if len(emb.size()) == 1: - emb = emb.unsqueeze(0) + if is_main_process: + ckpt_name = train_util.get_last_ckpt_name(args, "." 
+ args.save_model_as) + save_model(ckpt_name, updated_embs_list, global_step, num_train_epochs, force_sync_upload=True) - return emb + print("model saved.") def setup_parser() -> argparse.ArgumentParser: @@ -620,6 +763,11 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="ignore caption and use default templates for stype / キャプションは使わずデフォルトのスタイル用テンプレートで学習する", ) + parser.add_argument( + "--no_half_vae", + action="store_true", + help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", + ) return parser @@ -630,4 +778,5 @@ def setup_parser() -> argparse.ArgumentParser: args = parser.parse_args() args = train_util.read_config_from_file(args, parser) - train(args) + trainer = TextualInversionTrainer() + trainer.train(args) diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index f44d565cc..2c5673be1 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -8,9 +8,17 @@ from tqdm import tqdm import torch +try: + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + from library.ipex import ipex_init + ipex_init() +except Exception: + pass from accelerate.utils import set_seed import diffusers from diffusers import DDPMScheduler +import library import library.train_util as train_util import library.huggingface_util as huggingface_util @@ -27,6 +35,7 @@ apply_noise_offset, scale_v_prediction_loss_like_noise_prediction, ) +import library.original_unet as original_unet from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI imagenet_templates_small = [ @@ -107,7 +116,7 @@ def train(args): # acceleratorを準備する print("prepare accelerator") - accelerator, unwrap_model = train_util.prepare_accelerator(args) + accelerator = train_util.prepare_accelerator(args) # mixed precisionに対応した型を用意しておき適宜castする weight_dtype, save_dtype = train_util.prepare_dtype(args) @@ -187,7 +196,7 @@ def train(args): print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}") # データセットを準備する - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False)) + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, False)) if args.dataset_config is not None: print(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) @@ -232,7 +241,7 @@ def train(args): # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装 if use_template: - print("use template for training captions. is object: {args.use_object_template}") + print(f"use template for training captions. 
@@ -187,7 +196,7 @@ def train(args):
         print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")

     # データセットを準備する
-    blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False))
+    blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, False))
     if args.dataset_config is not None:
         print(f"Load dataset config from {args.dataset_config}")
         user_config = config_util.load_user_config(args.dataset_config)
@@ -232,7 +241,7 @@ def train(args):

     # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
     if use_template:
-        print("use template for training captions. is object: {args.use_object_template}")
+        print(f"use template for training captions. is object: {args.use_object_template}")
         templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small
         replace_to = " ".join(token_strings)
         captions = []
@@ -265,10 +274,10 @@ def train(args):
         ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"

     # モデルに xformers とか memory efficient attention を組み込む
-    train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
-    diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI
-    diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D.forward = downblock_forward_XTI
-    diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D.forward = upblock_forward_XTI
+    train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
+    original_unet.UNet2DConditionModel.forward = unet_forward_XTI
+    original_unet.CrossAttnDownBlock2D.forward = downblock_forward_XTI
+    original_unet.CrossAttnUpBlock2D.forward = upblock_forward_XTI

     # 学習を準備する
     if cache_latents:
@@ -328,7 +337,7 @@ def train(args):
     index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]

     # print(len(index_no_updates), torch.sum(index_no_updates))
-    orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
+    orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()

     # Freeze all parameters except for the token embeddings in text encoder
     text_encoder.requires_grad_(True)
@@ -382,9 +391,14 @@ def train(args):
         beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
     )
     prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
+    if args.zero_terminal_snr:
+        custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)

     if accelerator.is_main_process:
-        accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name)
+        init_kwargs = {}
+        if args.log_tracker_config is not None:
+            init_kwargs = toml.load(args.log_tracker_config)
+        accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)

     # function for saving/removing
     def save_model(ckpt_name, embs, steps, epoch_no, force_sync_upload=False):
@@ -433,20 +447,9 @@ def remove_model(old_ckpt_name):
                 ]
             )

-            # Sample noise that we'll add to the latents
-            noise = torch.randn_like(latents, device=latents.device)
-            if args.noise_offset:
-                noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale)
-            elif args.multires_noise_iterations:
-                noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount)
-
-            # Sample a random timestep for each image
-            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
-            timesteps = timesteps.long()
-
-            # Add noise to the latents according to the noise magnitude at each timestep
-            # (this is the forward diffusion process)
-            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+            # Sample noise, sample a random timestep for each image, and add noise to the latents,
+            # with noise offset and/or multires noise if specified
+            noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)

             # Predict the noise residual
             with accelerator.autocast():
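The inline noise-sampling code removed above is consolidated into `train_util.get_noise_noisy_latents_and_timesteps`. The sketch below reconstructs the equivalent logic from the removed lines as an illustration of what the helper is expected to do; it is not the actual implementation in `library/train_util.py`, and the import path for `apply_noise_offset` and `pyramid_noise_like` (the helpers already used by the removed code) is assumed to be `library.custom_train_functions`.

```python
import torch

from library.custom_train_functions import apply_noise_offset, pyramid_noise_like


def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
    # Sample the noise that will be added to the latents
    noise = torch.randn_like(latents, device=latents.device)
    if args.noise_offset:
        noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale)
    elif args.multires_noise_iterations:
        noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount)

    # Sample a random timestep for each image
    b_size = latents.shape[0]
    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
    timesteps = timesteps.long()

    # Add noise to the latents according to the noise magnitude at each timestep (forward diffusion)
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

    return noise, noisy_latents, timesteps
```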
@@ -482,7 +485,7 @@ def remove_model(old_ckpt_name):

             # Let's make sure we don't update any embedding weights besides the newly added token
             with torch.no_grad():
-                unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[
+                accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[
                     index_no_updates
                 ]

@@ -499,7 +502,13 @@ def remove_model(old_ckpt_name):
                 if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
                     accelerator.wait_for_everyone()
                     if accelerator.is_main_process:
-                        updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids_XTI].data.detach().clone()
+                        updated_embs = (
+                            accelerator.unwrap_model(text_encoder)
+                            .get_input_embeddings()
+                            .weight[token_ids_XTI]
+                            .data.detach()
+                            .clone()
+                        )

                         ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
                         save_model(ckpt_name, updated_embs, global_step, epoch)
@@ -515,7 +524,9 @@ def remove_model(old_ckpt_name):
             current_loss = loss.detach().item()
             if args.logging_dir is not None:
                 logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
-                if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():  # tracking d*lr value
+                if (
+                    args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
+                ):  # tracking d*lr value
                     logs["lr/d*lr"] = (
                         lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
                     )
@@ -535,7 +546,7 @@ def remove_model(old_ckpt_name):

         accelerator.wait_for_everyone()

-        updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids_XTI].data.detach().clone()
+        updated_embs = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[token_ids_XTI].data.detach().clone()

         if args.save_every_n_epochs is not None:
             saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
@@ -560,7 +571,7 @@ def remove_model(old_ckpt_name):

     is_main_process = accelerator.is_main_process
     if is_main_process:
-        text_encoder = unwrap_model(text_encoder)
+        text_encoder = accelerator.unwrap_model(text_encoder)

     accelerator.end_training()