diff --git a/.github/workflows/typos.yml b/.github/workflows/typos.yml index bd4ef334e..c9edf2650 100644 --- a/.github/workflows/typos.yml +++ b/.github/workflows/typos.yml @@ -18,4 +18,4 @@ jobs: - uses: actions/checkout@v4 - name: typos-action - uses: crate-ci/typos@v1.16.26 + uses: crate-ci/typos@v1.17.2 diff --git a/README-ja.md b/README-ja.md index 29c33a659..f70f882d7 100644 --- a/README-ja.md +++ b/README-ja.md @@ -1,7 +1,3 @@ -SDXLがサポートされました。sdxlブランチはmainブランチにマージされました。リポジトリを更新したときにはUpgradeの手順を実行してください。また accelerate のバージョンが上がっていますので、accelerate config を再度実行してください。 - -SDXL学習については[こちら](./README.md#sdxl-training)をご覧ください(英語です)。 - ## リポジトリについて Stable Diffusionの学習、画像生成、その他のスクリプトを入れたリポジトリです。 @@ -21,6 +17,7 @@ GUIやPowerShellスクリプトなど、より使いやすくする機能が[bma * [学習について、共通編](./docs/train_README-ja.md) : データ整備やオプションなど * [データセット設定](./docs/config_README-ja.md) +* [SDXL学習](./docs/train_SDXL-en.md) (英語版) * [DreamBoothの学習について](./docs/train_db_README-ja.md) * [fine-tuningのガイド](./docs/fine_tune_README_ja.md): * [LoRAの学習について](./docs/train_network_README-ja.md) @@ -44,9 +41,7 @@ PowerShellを使う場合、venvを使えるようにするためには以下の ## Windows環境でのインストール -スクリプトはPyTorch 2.0.1でテストしています。PyTorch 1.12.1でも動作すると思われます。 - -以下の例ではPyTorchは2.0.1/CUDA 11.8版をインストールします。CUDA 11.6版やPyTorch 1.12.1を使う場合は適宜書き換えください。 +スクリプトはPyTorch 2.1.2でテストしています。PyTorch 2.0.1、1.12.1でも動作すると思われます。 (なお、python -m venv~の行で「python」とだけ表示された場合、py -m venv~のようにpythonをpyに変更してください。) @@ -59,20 +54,20 @@ cd sd-scripts python -m venv venv .\venv\Scripts\activate -pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118 +pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu118 pip install --upgrade -r requirements.txt -pip install xformers==0.0.20 +pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu118 accelerate config ``` コマンドプロンプトでも同一です。 -(注:``python -m venv venv`` のほうが ``python -m venv --system-site-packages venv`` より安全そうなため書き換えました。globalなpythonにパッケージがインストールしてあると、後者だといろいろと問題が起きます。) +注:`bitsandbytes==0.43.0`、`prodigyopt==1.0`、`lion-pytorch==0.0.6` は `requirements.txt` に含まれるようになりました。他のバージョンを使う場合は適宜インストールしてください。 -accelerate configの質問には以下のように答えてください。(bf16で学習する場合、最後の質問にはbf16と答えてください。) +この例では PyTorch および xfomers は2.1.2/CUDA 11.8版をインストールします。CUDA 12.1版やPyTorch 1.12.1を使う場合は適宜書き換えください。たとえば CUDA 12.1版の場合は `pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu121` および `pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu121` としてください。 -※0.15.0から日本語環境では選択のためにカーソルキーを押すと落ちます(……)。数字キーの0、1、2……で選択できますので、そちらを使ってください。 +accelerate configの質問には以下のように答えてください。(bf16で学習する場合、最後の質問にはbf16と答えてください。) ```txt - This machine @@ -87,41 +82,6 @@ accelerate configの質問には以下のように答えてください。(bf1 ※場合によって ``ValueError: fp16 mixed precision requires a GPU`` というエラーが出ることがあるようです。この場合、6番目の質問( ``What GPU(s) (by id) should be used for training on this machine as a comma-separated list? [all]:``)に「0」と答えてください。(id `0`のGPUが使われます。) -### オプション:`bitsandbytes`(8bit optimizer)を使う - -`bitsandbytes`はオプションになりました。Linuxでは通常通りpipでインストールできます(0.41.1または以降のバージョンを推奨)。 - -Windowsでは0.35.0または0.41.1を推奨します。 - -- `bitsandbytes` 0.35.0: 安定しているとみられるバージョンです。AdamW8bitは使用できますが、他のいくつかの8bit optimizer、学習時の`full_bf16`オプションは使用できません。 -- `bitsandbytes` 0.41.1: Lion8bit、PagedAdamW8bit、PagedLion8bitをサポートします。`full_bf16`が使用できます。 - -注:`bitsandbytes` 0.35.0から0.41.0までのバージョンには問題があるようです。 https://github.com/TimDettmers/bitsandbytes/issues/659 - -以下の手順に従い、`bitsandbytes`をインストールしてください。 - -### 0.35.0を使う場合 - -PowerShellの例です。コマンドプロンプトではcpの代わりにcopyを使ってください。 - -```powershell -cd sd-scripts -.\venv\Scripts\activate -pip install bitsandbytes==0.35.0 - -cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\ -cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py -cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py -``` - -### 0.41.1を使う場合 - -jllllll氏の配布されている[こちら](https://github.com/jllllll/bitsandbytes-windows-webui) または他の場所から、Windows用のwhlファイルをインストールしてください。 - -```powershell -python -m pip install bitsandbytes==0.41.1 --prefer-binary --extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui -``` - ## アップグレード 新しいリリースがあった場合、以下のコマンドで更新できます。 @@ -151,4 +111,3 @@ Conv2d 3x3への拡大は [cloneofsimo氏](https://github.com/cloneofsimo/lora) [BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause - diff --git a/README.md b/README.md index 5919f08ca..b1b924ec2 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,3 @@ -__SDXL is now supported. The sdxl branch has been merged into the main branch. If you update the repository, please follow the upgrade instructions. Also, the version of accelerate has been updated, so please run accelerate config again.__ The documentation for SDXL training is [here](./README.md#sdxl-training). - This repository contains training, generation and utility scripts for Stable Diffusion. [__Change History__](#change-history) is moved to the bottom of the page. @@ -20,9 +18,9 @@ This repository contains the scripts for: ## About requirements.txt -These files do not contain requirements for PyTorch. Because the versions of them depend on your environment. Please install PyTorch at first (see installation guide below.) +The file does not contain requirements for PyTorch. Because the version of PyTorch depends on the environment, it is not included in the file. Please install PyTorch first according to the environment. See installation instructions below. -The scripts are tested with Pytorch 2.0.1. 1.12.1 is not tested but should work. +The scripts are tested with Pytorch 2.1.2. 2.0.1 and 1.12.1 is not tested but should work. ## Links to usage documentation @@ -32,11 +30,13 @@ Most of the documents are written in Japanese. * [Training guide - common](./docs/train_README-ja.md) : data preparation, options etc... * [Chinese version](./docs/train_README-zh.md) +* [SDXL training](./docs/train_SDXL-en.md) (English version) * [Dataset config](./docs/config_README-ja.md) + * [English version](./docs/config_README-en.md) * [DreamBooth training guide](./docs/train_db_README-ja.md) * [Step by Step fine-tuning guide](./docs/fine_tune_README_ja.md): -* [training LoRA](./docs/train_network_README-ja.md) -* [training Textual Inversion](./docs/train_ti_README-ja.md) +* [Training LoRA](./docs/train_network_README-ja.md) +* [Training Textual Inversion](./docs/train_ti_README-ja.md) * [Image generation](./docs/gen_img_README-ja.md) * note.com [Model conversion](https://note.com/kohya_ss/n/n374f316fe4ad) @@ -64,14 +64,18 @@ cd sd-scripts python -m venv venv .\venv\Scripts\activate -pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118 +pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu118 pip install --upgrade -r requirements.txt -pip install xformers==0.0.20 +pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu118 accelerate config ``` -__Note:__ Now bitsandbytes is optional. Please install any version of bitsandbytes as needed. Installation instructions are in the following section. +If `python -m venv` shows only `python`, change `python` to `py`. + +__Note:__ Now `bitsandbytes==0.43.0`, `prodigyopt==1.0` and `lion-pytorch==0.0.6` are included in the requirements.txt. If you'd like to use the another version, please install it manually. + +This installation is for CUDA 11.8. If you use a different version of CUDA, please install the appropriate version of PyTorch and xformers. For example, if you use CUDA 12, please install `pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu121` and `pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu121`. [[datasets]] No.2 ┘ +``` + +The image directory corresponds to each `[[datasets.subsets]]`. Then, multiple `[[datasets.subsets]]` are combined to form one `[[datasets]]`. All `[[datasets]]` and `[[datasets.subsets]]` belong to `[general]`. + +The available options for each registration location may differ, but if the same option is specified, the value in the lower registration location will take precedence. You can check how the `keep_tokens` option is handled in the previous example for better understanding. + +Additionally, the available options may vary depending on the method that the learning approach supports. + +* Options specific to the DreamBooth method +* Options specific to the fine-tuning method +* Options available when using the caption dropout technique + +When using both the DreamBooth method and the fine-tuning method, they can be used together with a learning approach that supports both. +When using them together, a point to note is that the method is determined based on the dataset, so it is not possible to mix DreamBooth method subsets and fine-tuning method subsets within the same dataset. +In other words, if you want to use both methods together, you need to set up subsets of different methods belonging to different datasets. + +In terms of program behavior, if the `metadata_file` option exists, it is determined to be a subset of fine-tuning. Therefore, for subsets belonging to the same dataset, as long as they are either "all have the `metadata_file` option" or "all have no `metadata_file` option," there is no problem. + +Below, the available options will be explained. For options with the same name as the command-line argument, the explanation will be omitted in principle. Please refer to other READMEs. + +### Common options for all learning methods + +These are options that can be specified regardless of the learning method. + +#### Data set specific options + +These are options related to the configuration of the data set. They cannot be described in `datasets.subsets`. + + +| Option Name | Example Setting | `[general]` | `[[datasets]]` | +| ---- | ---- | ---- | ---- | +| `batch_size` | `1` | o | o | +| `bucket_no_upscale` | `true` | o | o | +| `bucket_reso_steps` | `64` | o | o | +| `enable_bucket` | `true` | o | o | +| `max_bucket_reso` | `1024` | o | o | +| `min_bucket_reso` | `128` | o | o | +| `resolution` | `256`, `[512, 512]` | o | o | + +* `batch_size` + * This corresponds to the command-line argument `--train_batch_size`. + +These settings are fixed per dataset. That means that subsets belonging to the same dataset will share these settings. For example, if you want to prepare datasets with different resolutions, you can define them as separate datasets as shown in the example above, and set different resolutions for each. + +#### Options for Subsets + +These options are related to subset configuration. + +| Option Name | Example | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` | +| ---- | ---- | ---- | ---- | ---- | +| `color_aug` | `false` | o | o | o | +| `face_crop_aug_range` | `[1.0, 3.0]` | o | o | o | +| `flip_aug` | `true` | o | o | o | +| `keep_tokens` | `2` | o | o | o | +| `num_repeats` | `10` | o | o | o | +| `random_crop` | `false` | o | o | o | +| `shuffle_caption` | `true` | o | o | o | +| `caption_prefix` | `"masterpiece, best quality, "` | o | o | o | +| `caption_suffix` | `", from side"` | o | o | o | +| `caption_separator` | (not specified) | o | o | o | +| `keep_tokens_separator` | `“|||”` | o | o | o | +| `secondary_separator` | `“;;;”` | o | o | o | +| `enable_wildcard` | `true` | o | o | o | + +* `num_repeats` + * Specifies the number of repeats for images in a subset. This is equivalent to `--dataset_repeats` in fine-tuning but can be specified for any training method. +* `caption_prefix`, `caption_suffix` + * Specifies the prefix and suffix strings to be appended to the captions. Shuffling is performed with these strings included. Be cautious when using `keep_tokens`. +* `caption_separator` + * Specifies the string to separate the tags. The default is `,`. This option is usually not necessary to set. +* `keep_tokens_separator` + * Specifies the string to separate the parts to be fixed in the caption. For example, if you specify `aaa, bbb ||| ccc, ddd, eee, fff ||| ggg, hhh`, the parts `aaa, bbb` and `ggg, hhh` will remain, and the rest will be shuffled and dropped. The comma in between is not necessary. As a result, the prompt will be `aaa, bbb, eee, ccc, fff, ggg, hhh` or `aaa, bbb, fff, ccc, eee, ggg, hhh`, etc. +* `secondary_separator` + * Specifies an additional separator. The part separated by this separator is treated as one tag and is shuffled and dropped. It is then replaced by `caption_separator`. For example, if you specify `aaa;;;bbb;;;ccc`, it will be replaced by `aaa,bbb,ccc` or dropped together. +* `enable_wildcard` + * Enables wildcard notation. This will be explained later. + +### DreamBooth-specific options + +DreamBooth-specific options only exist as subsets-specific options. + +#### Subset-specific options + +Options related to the configuration of DreamBooth subsets. + +| Option Name | Example Setting | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` | +| ---- | ---- | ---- | ---- | ---- | +| `image_dir` | `'C:\hoge'` | - | - | o (required) | +| `caption_extension` | `".txt"` | o | o | o | +| `class_tokens` | `"sks girl"` | - | - | o | +| `is_reg` | `false` | - | - | o | + +Firstly, note that for `image_dir`, the path to the image files must be specified as being directly in the directory. Unlike the previous DreamBooth method, where images had to be placed in subdirectories, this is not compatible with that specification. Also, even if you name the folder something like "5_cat", the number of repeats of the image and the class name will not be reflected. If you want to set these individually, you will need to explicitly specify them using `num_repeats` and `class_tokens`. + +* `image_dir` + * Specifies the path to the image directory. This is a required option. + * Images must be placed directly under the directory. +* `class_tokens` + * Sets the class tokens. + * Only used during training when a corresponding caption file does not exist. The determination of whether or not to use it is made on a per-image basis. If `class_tokens` is not specified and a caption file is not found, an error will occur. +* `is_reg` + * Specifies whether the subset images are for normalization. If not specified, it is set to `false`, meaning that the images are not for normalization. + +### Fine-tuning method specific options + +The options for the fine-tuning method only exist for subset-specific options. + +#### Subset-specific options + +These options are related to the configuration of the fine-tuning method's subsets. + +| Option name | Example setting | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` | +| ---- | ---- | ---- | ---- | ---- | +| `image_dir` | `'C:\hoge'` | - | - | o | +| `metadata_file` | `'C:\piyo\piyo_md.json'` | - | - | o (required) | + +* `image_dir` + * Specify the path to the image directory. Unlike the DreamBooth method, specifying it is not mandatory, but it is recommended to do so. + * The case where it is not necessary to specify is when the `--full_path` is added to the command line when generating the metadata file. + * The images must be placed directly under the directory. +* `metadata_file` + * Specify the path to the metadata file used for the subset. This is a required option. + * It is equivalent to the command-line argument `--in_json`. + * Due to the specification that a metadata file must be specified for each subset, it is recommended to avoid creating a metadata file with images from different directories as a single metadata file. It is strongly recommended to prepare a separate metadata file for each image directory and register them as separate subsets. + +### Options available when caption dropout method can be used + +The options available when the caption dropout method can be used exist only for subsets. Regardless of whether it's the DreamBooth method or fine-tuning method, if it supports caption dropout, it can be specified. + +#### Subset-specific options + +Options related to the setting of subsets that caption dropout can be used for. + +| Option Name | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` | +| ---- | ---- | ---- | ---- | +| `caption_dropout_every_n_epochs` | o | o | o | +| `caption_dropout_rate` | o | o | o | +| `caption_tag_dropout_rate` | o | o | o | + +## Behavior when there are duplicate subsets + +In the case of the DreamBooth dataset, if there are multiple `image_dir` directories with the same content, they are considered to be duplicate subsets. For the fine-tuning dataset, if there are multiple `metadata_file` files with the same content, they are considered to be duplicate subsets. If duplicate subsets exist in the dataset, subsequent subsets will be ignored. + +However, if they belong to different datasets, they are not considered duplicates. For example, if you have subsets with the same `image_dir` in different datasets, they will not be considered duplicates. This is useful when you want to train with the same image but with different resolutions. + +```toml +# If data sets exist separately, they are not considered duplicates and are both used for training. + +[[datasets]] +resolution = 512 + + [[datasets.subsets]] + image_dir = 'C:\hoge' + +[[datasets]] +resolution = 768 + + [[datasets.subsets]] + image_dir = 'C:\hoge' +``` + +## Command Line Argument and Configuration File + +There are options in the configuration file that have overlapping roles with command line argument options. + +The following command line argument options are ignored if a configuration file is passed: + +* `--train_data_dir` +* `--reg_data_dir` +* `--in_json` + +The following command line argument options are given priority over the configuration file options if both are specified simultaneously. In most cases, they have the same names as the corresponding options in the configuration file. + +| Command Line Argument Option | Prioritized Configuration File Option | +| ------------------------------- | ------------------------------------- | +| `--bucket_no_upscale` | | +| `--bucket_reso_steps` | | +| `--caption_dropout_every_n_epochs` | | +| `--caption_dropout_rate` | | +| `--caption_extension` | | +| `--caption_tag_dropout_rate` | | +| `--color_aug` | | +| `--dataset_repeats` | `num_repeats` | +| `--enable_bucket` | | +| `--face_crop_aug_range` | | +| `--flip_aug` | | +| `--keep_tokens` | | +| `--min_bucket_reso` | | +| `--random_crop` | | +| `--resolution` | | +| `--shuffle_caption` | | +| `--train_batch_size` | `batch_size` | + +## Error Guide + +Currently, we are using an external library to check if the configuration file is written correctly, but the development has not been completed, and there is a problem that the error message is not clear. In the future, we plan to improve this problem. + +As a temporary measure, we will list common errors and their solutions. If you encounter an error even though it should be correct or if the error content is not understandable, please contact us as it may be a bug. + +* `voluptuous.error.MultipleInvalid: required key not provided @ ...`: This error occurs when a required option is not provided. It is highly likely that you forgot to specify the option or misspelled the option name. + * The error location is indicated by `...` in the error message. For example, if you encounter an error like `voluptuous.error.MultipleInvalid: required key not provided @ data['datasets'][0]['subsets'][0]['image_dir']`, it means that the `image_dir` option does not exist in the 0th `subsets` of the 0th `datasets` setting. +* `voluptuous.error.MultipleInvalid: expected int for dictionary value @ ...`: This error occurs when the specified value format is incorrect. It is highly likely that the value format is incorrect. The `int` part changes depending on the target option. The example configurations in this README may be helpful. +* `voluptuous.error.MultipleInvalid: extra keys not allowed @ ...`: This error occurs when there is an option name that is not supported. It is highly likely that you misspelled the option name or mistakenly included it. + +## Miscellaneous + +### Multi-line captions + +By setting `enable_wildcard = true`, multiple-line captions are also enabled. If the caption file consists of multiple lines, one line is randomly selected as the caption. + +```txt +1girl, hatsune miku, vocaloid, upper body, looking at viewer, microphone, stage +a girl with a microphone standing on a stage +detailed digital art of a girl with a microphone on a stage +``` + +It can be combined with wildcard notation. + +In metadata files, you can also specify multiple-line captions. In the `.json` metadata file, use `\n` to represent a line break. If the caption file consists of multiple lines, `merge_captions_to_metadata.py` will create a metadata file in this format. + +The tags in the metadata (`tags`) are added to each line of the caption. + +```json +{ + "/path/to/image.png": { + "caption": "a cartoon of a frog with the word frog on it\ntest multiline caption1\ntest multiline caption2", + "tags": "open mouth, simple background, standing, no humans, animal, black background, frog, animal costume, animal focus" + }, + ... +} +``` + +In this case, the actual caption will be `a cartoon of a frog with the word frog on it, open mouth, simple background ...`, `test multiline caption1, open mouth, simple background ...`, `test multiline caption2, open mouth, simple background ...`, etc. + +### Example of configuration file : `secondary_separator`, wildcard notation, `keep_tokens_separator`, etc. + +```toml +[general] +flip_aug = true +color_aug = false +resolution = [1024, 1024] + +[[datasets]] +batch_size = 6 +enable_bucket = true +bucket_no_upscale = true +caption_extension = ".txt" +keep_tokens_separator= "|||" +shuffle_caption = true +caption_tag_dropout_rate = 0.1 +secondary_separator = ";;;" # subset 側に書くこともできます / can be written in the subset side +enable_wildcard = true # 同上 / same as above + + [[datasets.subsets]] + image_dir = "/path/to/image_dir" + num_repeats = 1 + + # ||| の前後はカンマは不要です(自動的に追加されます) / No comma is required before and after ||| (it is added automatically) + caption_prefix = "1girl, hatsune miku, vocaloid |||" + + # ||| の後はシャッフル、drop されず残ります / After |||, it is not shuffled or dropped and remains + # 単純に文字列として連結されるので、カンマなどは自分で入れる必要があります / It is simply concatenated as a string, so you need to put commas yourself + caption_suffix = ", anime screencap ||| masterpiece, rating: general" +``` + +### Example of caption, secondary_separator notation: `secondary_separator = ";;;"` + +```txt +1girl, hatsune miku, vocaloid, upper body, looking at viewer, sky;;;cloud;;;day, outdoors +``` +The part `sky;;;cloud;;;day` is replaced with `sky,cloud,day` without shuffling or dropping. When shuffling and dropping are enabled, it is processed as a whole (as one tag). For example, it becomes `vocaloid, 1girl, upper body, sky,cloud,day, outdoors, hatsune miku` (shuffled) or `vocaloid, 1girl, outdoors, looking at viewer, upper body, hatsune miku` (dropped). + +### Example of caption, enable_wildcard notation: `enable_wildcard = true` + +```txt +1girl, hatsune miku, vocaloid, upper body, looking at viewer, {simple|white} background +``` +`simple` or `white` is randomly selected, and it becomes `simple background` or `white background`. + +```txt +1girl, hatsune miku, vocaloid, {{retro style}} +``` +If you want to include `{` or `}` in the tag string, double them like `{{` or `}}` (in this example, the actual caption used for training is `{retro style}`). + +### Example of caption, `keep_tokens_separator` notation: `keep_tokens_separator = "|||"` + +```txt +1girl, hatsune miku, vocaloid ||| stage, microphone, white shirt, smile ||| best quality, rating: general +``` +It becomes `1girl, hatsune miku, vocaloid, microphone, stage, white shirt, best quality, rating: general` or `1girl, hatsune miku, vocaloid, white shirt, smile, stage, microphone, best quality, rating: general` etc. + diff --git a/docs/config_README-ja.md b/docs/config_README-ja.md index 69a03f6cf..b57ae86a7 100644 --- a/docs/config_README-ja.md +++ b/docs/config_README-ja.md @@ -1,5 +1,3 @@ -For non-Japanese speakers: this README is provided only in Japanese in the current state. Sorry for inconvenience. We will provide English version in the near future. - `--dataset_config` で渡すことができる設定ファイルに関する説明です。 ## 概要 @@ -140,12 +138,28 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学 | `shuffle_caption` | `true` | o | o | o | | `caption_prefix` | `“masterpiece, best quality, ”` | o | o | o | | `caption_suffix` | `“, from side”` | o | o | o | +| `caption_separator` | (通常は設定しません) | o | o | o | +| `keep_tokens_separator` | `“|||”` | o | o | o | +| `secondary_separator` | `“;;;”` | o | o | o | +| `enable_wildcard` | `true` | o | o | o | * `num_repeats` * サブセットの画像の繰り返し回数を指定します。fine tuning における `--dataset_repeats` に相当しますが、`num_repeats` はどの学習方法でも指定可能です。 * `caption_prefix`, `caption_suffix` * キャプションの前、後に付与する文字列を指定します。シャッフルはこれらの文字列を含めた状態で行われます。`keep_tokens` を指定する場合には注意してください。 +* `caption_separator` + * タグを区切る文字列を指定します。デフォルトは `,` です。このオプションは通常は設定する必要はありません。 + +* `keep_tokens_separator` + * キャプションで固定したい部分を区切る文字列を指定します。たとえば `aaa, bbb ||| ccc, ddd, eee, fff ||| ggg, hhh` のように指定すると、`aaa, bbb` と `ggg, hhh` の部分はシャッフル、drop されず残ります。間のカンマは不要です。結果としてプロンプトは `aaa, bbb, eee, ccc, fff, ggg, hhh` や `aaa, bbb, fff, ccc, eee, ggg, hhh` などになります。 + +* `secondary_separator` + * 追加の区切り文字を指定します。この区切り文字で区切られた部分は一つのタグとして扱われ、シャッフル、drop されます。その後、`caption_separator` に置き換えられます。たとえば `aaa;;;bbb;;;ccc` のように指定すると、`aaa,bbb,ccc` に置き換えられるか、まとめて drop されます。 + +* `enable_wildcard` + * ワイルドカード記法および複数行キャプションを有効にします。ワイルドカード記法、複数行キャプションについては後述します。 + ### DreamBooth 方式専用のオプション DreamBooth 方式のオプションは、サブセット向けオプションのみ存在します。 @@ -280,4 +294,89 @@ resolution = 768 * `voluptuous.error.MultipleInvalid: expected int for dictionary value @ ...`: 指定する値の形式が不正というエラーです。値の形式が間違っている可能性が高いです。`int` の部分は対象となるオプションによって変わります。この README に載っているオプションの「設定例」が役立つかもしれません。 * `voluptuous.error.MultipleInvalid: extra keys not allowed @ ...`: 対応していないオプション名が存在している場合に発生するエラーです。オプション名を間違って記述しているか、誤って紛れ込んでいる可能性が高いです。 +## その他 + +### 複数行キャプション + +`enable_wildcard = true` を設定することで、複数行キャプションも同時に有効になります。キャプションファイルが複数の行からなる場合、ランダムに一つの行が選ばれてキャプションとして利用されます。 + +```txt +1girl, hatsune miku, vocaloid, upper body, looking at viewer, microphone, stage +a girl with a microphone standing on a stage +detailed digital art of a girl with a microphone on a stage +``` + +ワイルドカード記法と組み合わせることも可能です。 + +メタデータファイルでも同様に複数行キャプションを指定することができます。メタデータの .json 内には、`\n` を使って改行を表現してください。キャプションファイルが複数行からなる場合、`merge_captions_to_metadata.py` を使うと、この形式でメタデータファイルが作成されます。 + +メタデータのタグ (`tags`) は、キャプションの各行に追加されます。 + +```json +{ + "/path/to/image.png": { + "caption": "a cartoon of a frog with the word frog on it\ntest multiline caption1\ntest multiline caption2", + "tags": "open mouth, simple background, standing, no humans, animal, black background, frog, animal costume, animal focus" + }, + ... +} +``` +この場合、実際のキャプションは `a cartoon of a frog with the word frog on it, open mouth, simple background ...` または `test multiline caption1, open mouth, simple background ...`、 `test multiline caption2, open mouth, simple background ...` 等になります。 + +### 設定ファイルの記述例:追加の区切り文字、ワイルドカード記法、`keep_tokens_separator` 等 + +```toml +[general] +flip_aug = true +color_aug = false +resolution = [1024, 1024] + +[[datasets]] +batch_size = 6 +enable_bucket = true +bucket_no_upscale = true +caption_extension = ".txt" +keep_tokens_separator= "|||" +shuffle_caption = true +caption_tag_dropout_rate = 0.1 +secondary_separator = ";;;" # subset 側に書くこともできます / can be written in the subset side +enable_wildcard = true # 同上 / same as above + + [[datasets.subsets]] + image_dir = "/path/to/image_dir" + num_repeats = 1 + + # ||| の前後はカンマは不要です(自動的に追加されます) / No comma is required before and after ||| (it is added automatically) + caption_prefix = "1girl, hatsune miku, vocaloid |||" + + # ||| の後はシャッフル、drop されず残ります / After |||, it is not shuffled or dropped and remains + # 単純に文字列として連結されるので、カンマなどは自分で入れる必要があります / It is simply concatenated as a string, so you need to put commas yourself + caption_suffix = ", anime screencap ||| masterpiece, rating: general" +``` + +### キャプション記述例、secondary_separator 記法:`secondary_separator = ";;;"` の場合 + +```txt +1girl, hatsune miku, vocaloid, upper body, looking at viewer, sky;;;cloud;;;day, outdoors +``` +`sky;;;cloud;;;day` の部分はシャッフル、drop されず `sky,cloud,day` に置換されます。シャッフル、drop が有効な場合、まとめて(一つのタグとして)処理されます。つまり `vocaloid, 1girl, upper body, sky,cloud,day, outdoors, hatsune miku` (シャッフル)や `vocaloid, 1girl, outdoors, looking at viewer, upper body, hatsune miku` (drop されたケース)などになります。 + +### キャプション記述例、ワイルドカード記法: `enable_wildcard = true` の場合 + +```txt +1girl, hatsune miku, vocaloid, upper body, looking at viewer, {simple|white} background +``` +ランダムに `simple` または `white` が選ばれ、`simple background` または `white background` になります。 + +```txt +1girl, hatsune miku, vocaloid, {{retro style}} +``` +タグ文字列に `{` や `}` そのものを含めたい場合は `{{` や `}}` のように二つ重ねてください(この例では実際に学習に用いられるキャプションは `{retro style}` になります)。 + +### キャプション記述例、`keep_tokens_separator` 記法: `keep_tokens_separator = "|||"` の場合 + +```txt +1girl, hatsune miku, vocaloid ||| stage, microphone, white shirt, smile ||| best quality, rating: general +``` +`1girl, hatsune miku, vocaloid, microphone, stage, white shirt, best quality, rating: general` や `1girl, hatsune miku, vocaloid, white shirt, smile, stage, microphone, best quality, rating: general` などになります。 diff --git a/docs/gen_img_README-ja.md b/docs/gen_img_README-ja.md index cf35f1df7..8f4442d00 100644 --- a/docs/gen_img_README-ja.md +++ b/docs/gen_img_README-ja.md @@ -452,3 +452,36 @@ python gen_img_diffusers.py --ckpt wd-v1-3-full-pruned-half.ckpt - `--network_show_meta` : 追加ネットワークのメタデータを表示します。 + +--- + +# About Gradual Latent + +Gradual Latent is a Hires fix that gradually increases the size of the latent. `gen_img.py`, `sdxl_gen_img.py`, and `gen_img_diffusers.py` have the following options. + +- `--gradual_latent_timesteps`: Specifies the timestep to start increasing the size of the latent. The default is None, which means Gradual Latent is not used. Please try around 750 at first. +- `--gradual_latent_ratio`: Specifies the initial size of the latent. The default is 0.5, which means it starts with half the default latent size. +- `--gradual_latent_ratio_step`: Specifies the ratio to increase the size of the latent. The default is 0.125, which means the latent size is gradually increased to 0.625, 0.75, 0.875, 1.0. +- `--gradual_latent_ratio_every_n_steps`: Specifies the interval to increase the size of the latent. The default is 3, which means the latent size is increased every 3 steps. + +Each option can also be specified with prompt options, `--glt`, `--glr`, `--gls`, `--gle`. + +__Please specify `euler_a` for the sampler.__ Because the source code of the sampler is modified. It will not work with other samplers. + +It is more effective with SD 1.5. It is quite subtle with SDXL. + +# Gradual Latent について + +latentのサイズを徐々に大きくしていくHires fixです。`gen_img.py` 、``sdxl_gen_img.py` 、`gen_img_diffusers.py` に以下のオプションが追加されています。 + +- `--gradual_latent_timesteps` : latentのサイズを大きくし始めるタイムステップを指定します。デフォルトは None で、Gradual Latentを使用しません。750 くらいから始めてみてください。 +- `--gradual_latent_ratio` : latentの初期サイズを指定します。デフォルトは 0.5 で、デフォルトの latent サイズの半分のサイズから始めます。 +- `--gradual_latent_ratio_step`: latentのサイズを大きくする割合を指定します。デフォルトは 0.125 で、latentのサイズを 0.625, 0.75, 0.875, 1.0 と徐々に大きくします。 +- `--gradual_latent_ratio_every_n_steps`: latentのサイズを大きくする間隔を指定します。デフォルトは 3 で、3ステップごとに latent のサイズを大きくします。 + +それぞれのオプションは、プロンプトオプション、`--glt`、`--glr`、`--gls`、`--gle` でも指定できます。 + +サンプラーに手を加えているため、__サンプラーに `euler_a` を指定してください。__ 他のサンプラーでは動作しません。 + +SD 1.5 のほうが効果があります。SDXL ではかなり微妙です。 + diff --git a/docs/train_SDXL-en.md b/docs/train_SDXL-en.md new file mode 100644 index 000000000..a4c55b3fd --- /dev/null +++ b/docs/train_SDXL-en.md @@ -0,0 +1,84 @@ +## SDXL training + +The documentation will be moved to the training documentation in the future. The following is a brief explanation of the training scripts for SDXL. + +### Training scripts for SDXL + +- `sdxl_train.py` is a script for SDXL fine-tuning. The usage is almost the same as `fine_tune.py`, but it also supports DreamBooth dataset. + - `--full_bf16` option is added. Thanks to KohakuBlueleaf! + - This option enables the full bfloat16 training (includes gradients). This option is useful to reduce the GPU memory usage. + - The full bfloat16 training might be unstable. Please use it at your own risk. + - The different learning rates for each U-Net block are now supported in sdxl_train.py. Specify with `--block_lr` option. Specify 23 values separated by commas like `--block_lr 1e-3,1e-3 ... 1e-3`. + - 23 values correspond to `0: time/label embed, 1-9: input blocks 0-8, 10-12: mid blocks 0-2, 13-21: output blocks 0-8, 22: out`. +- `prepare_buckets_latents.py` now supports SDXL fine-tuning. + +- `sdxl_train_network.py` is a script for LoRA training for SDXL. The usage is almost the same as `train_network.py`. + +- Both scripts has following additional options: + - `--cache_text_encoder_outputs` and `--cache_text_encoder_outputs_to_disk`: Cache the outputs of the text encoders. This option is useful to reduce the GPU memory usage. This option cannot be used with options for shuffling or dropping the captions. + - `--no_half_vae`: Disable the half-precision (mixed-precision) VAE. VAE for SDXL seems to produce NaNs in some cases. This option is useful to avoid the NaNs. + +- `--weighted_captions` option is not supported yet for both scripts. + +- `sdxl_train_textual_inversion.py` is a script for Textual Inversion training for SDXL. The usage is almost the same as `train_textual_inversion.py`. + - `--cache_text_encoder_outputs` is not supported. + - There are two options for captions: + 1. Training with captions. All captions must include the token string. The token string is replaced with multiple tokens. + 2. Use `--use_object_template` or `--use_style_template` option. The captions are generated from the template. The existing captions are ignored. + - See below for the format of the embeddings. + +- `--min_timestep` and `--max_timestep` options are added to each training script. These options can be used to train U-Net with different timesteps. The default values are 0 and 1000. + +### Utility scripts for SDXL + +- `tools/cache_latents.py` is added. This script can be used to cache the latents to disk in advance. + - The options are almost the same as `sdxl_train.py'. See the help message for the usage. + - Please launch the script as follows: + `accelerate launch --num_cpu_threads_per_process 1 tools/cache_latents.py ...` + - This script should work with multi-GPU, but it is not tested in my environment. + +- `tools/cache_text_encoder_outputs.py` is added. This script can be used to cache the text encoder outputs to disk in advance. + - The options are almost the same as `cache_latents.py` and `sdxl_train.py`. See the help message for the usage. + +- `sdxl_gen_img.py` is added. This script can be used to generate images with SDXL, including LoRA, Textual Inversion and ControlNet-LLLite. See the help message for the usage. + +### Tips for SDXL training + +- The default resolution of SDXL is 1024x1024. +- The fine-tuning can be done with 24GB GPU memory with the batch size of 1. For 24GB GPU, the following options are recommended __for the fine-tuning with 24GB GPU memory__: + - Train U-Net only. + - Use gradient checkpointing. + - Use `--cache_text_encoder_outputs` option and caching latents. + - Use Adafactor optimizer. RMSprop 8bit or Adagrad 8bit may work. AdamW 8bit doesn't seem to work. +- The LoRA training can be done with 8GB GPU memory (10GB recommended). For reducing the GPU memory usage, the following options are recommended: + - Train U-Net only. + - Use gradient checkpointing. + - Use `--cache_text_encoder_outputs` option and caching latents. + - Use one of 8bit optimizers or Adafactor optimizer. + - Use lower dim (4 to 8 for 8GB GPU). +- `--network_train_unet_only` option is highly recommended for SDXL LoRA. Because SDXL has two text encoders, the result of the training will be unexpected. +- PyTorch 2 seems to use slightly less GPU memory than PyTorch 1. +- `--bucket_reso_steps` can be set to 32 instead of the default value 64. Smaller values than 32 will not work for SDXL training. + +Example of the optimizer settings for Adafactor with the fixed learning rate: +```toml +optimizer_type = "adafactor" +optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ] +lr_scheduler = "constant_with_warmup" +lr_warmup_steps = 100 +learning_rate = 4e-7 # SDXL original learning rate +``` + +### Format of Textual Inversion embeddings for SDXL + +```python +from safetensors.torch import save_file + +state_dict = {"clip_g": embs_for_text_encoder_1280, "clip_l": embs_for_text_encoder_768} +save_file(state_dict, file) +``` + +### ControlNet-LLLite + +ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [documentation](./docs/train_lllite_README.md) for details. + diff --git a/fine_tune.py b/fine_tune.py index be61b3d16..46f128287 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -2,27 +2,27 @@ # XXX dropped option: hypernetwork training import argparse -import gc import math import os from multiprocessing import Value import toml from tqdm import tqdm -import torch - -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +import torch +from library.device_utils import init_ipex, clean_memory_on_device +init_ipex() - ipex_init() -except Exception: - pass from accelerate.utils import set_seed from diffusers import DDPMScheduler +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + import library.train_util as train_util import library.config_util as config_util from library.config_util import ( @@ -42,6 +42,7 @@ def train(args): train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) + setup_logging(args, reset=True) cache_latents = args.cache_latents @@ -54,11 +55,11 @@ def train(args): if args.dataset_class is None: blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, False, True)) if args.dataset_config is not None: - print(f"Load dataset config from {args.dataset_config}") + logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) @@ -91,7 +92,7 @@ def train(args): train_util.debug_dataset(train_dataset_group) return if len(train_dataset_group) == 0: - print( + logger.error( "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" ) return @@ -102,7 +103,7 @@ def train(args): ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") accelerator = train_util.prepare_accelerator(args) # mixed precisionに対応した型を用意しておき適宜castする @@ -163,9 +164,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() @@ -212,8 +211,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): _, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params) # dataloaderを準備する - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( train_dataset_group, batch_size=1, @@ -228,7 +227,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): args.max_train_steps = args.max_train_epochs * math.ceil( len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) - accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) @@ -292,7 +293,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): if accelerator.is_main_process: init_kwargs = {} if args.wandb_run_name: - init_kwargs['wandb'] = {'name': args.wandb_run_name} + init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) @@ -456,7 +457,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): accelerator.end_training() - if args.save_state and is_main_process: + if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) del accelerator # この後メモリを使うのでこれは消す @@ -466,12 +467,13 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): train_util.save_sd_model_on_train_end( args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae ) - print("model saved.") + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) @@ -480,7 +482,9 @@ def setup_parser() -> argparse.ArgumentParser: config_util.add_config_arguments(parser) custom_train_functions.add_custom_train_arguments(parser) - parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する") + parser.add_argument( + "--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する" + ) parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") parser.add_argument( "--learning_rate_te", diff --git a/finetune/blip/blip.py b/finetune/blip/blip.py index 7851fb08b..7d192cb26 100644 --- a/finetune/blip/blip.py +++ b/finetune/blip/blip.py @@ -21,6 +21,10 @@ import os from urllib.parse import urlparse from timm.models.hub import download_cached_file +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) class BLIP_Base(nn.Module): def __init__(self, @@ -235,6 +239,6 @@ def load_checkpoint(model,url_or_filename): del state_dict[key] msg = model.load_state_dict(state_dict,strict=False) - print('load checkpoint from %s'%url_or_filename) + logger.info('load checkpoint from %s'%url_or_filename) return model,msg diff --git a/finetune/clean_captions_and_tags.py b/finetune/clean_captions_and_tags.py index 68839eccc..5aeb17425 100644 --- a/finetune/clean_captions_and_tags.py +++ b/finetune/clean_captions_and_tags.py @@ -8,6 +8,10 @@ import re from tqdm import tqdm +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) PATTERN_HAIR_LENGTH = re.compile(r', (long|short|medium) hair, ') PATTERN_HAIR_CUT = re.compile(r', (bob|hime) cut, ') @@ -36,13 +40,13 @@ def clean_tags(image_key, tags): tokens = tags.split(", rating") if len(tokens) == 1: # WD14 taggerのときはこちらになるのでメッセージは出さない - # print("no rating:") - # print(f"{image_key} {tags}") + # logger.info("no rating:") + # logger.info(f"{image_key} {tags}") pass else: if len(tokens) > 2: - print("multiple ratings:") - print(f"{image_key} {tags}") + logger.info("multiple ratings:") + logger.info(f"{image_key} {tags}") tags = tokens[0] tags = ", " + tags.replace(", ", ", , ") + ", " # カンマ付きで検索をするための身も蓋もない対策 @@ -124,43 +128,43 @@ def clean_caption(caption): def main(args): if os.path.exists(args.in_json): - print(f"loading existing metadata: {args.in_json}") + logger.info(f"loading existing metadata: {args.in_json}") with open(args.in_json, "rt", encoding='utf-8') as f: metadata = json.load(f) else: - print("no metadata / メタデータファイルがありません") + logger.error("no metadata / メタデータファイルがありません") return - print("cleaning captions and tags.") + logger.info("cleaning captions and tags.") image_keys = list(metadata.keys()) for image_key in tqdm(image_keys): tags = metadata[image_key].get('tags') if tags is None: - print(f"image does not have tags / メタデータにタグがありません: {image_key}") + logger.error(f"image does not have tags / メタデータにタグがありません: {image_key}") else: org = tags tags = clean_tags(image_key, tags) metadata[image_key]['tags'] = tags if args.debug and org != tags: - print("FROM: " + org) - print("TO: " + tags) + logger.info("FROM: " + org) + logger.info("TO: " + tags) caption = metadata[image_key].get('caption') if caption is None: - print(f"image does not have caption / メタデータにキャプションがありません: {image_key}") + logger.error(f"image does not have caption / メタデータにキャプションがありません: {image_key}") else: org = caption caption = clean_caption(caption) metadata[image_key]['caption'] = caption if args.debug and org != caption: - print("FROM: " + org) - print("TO: " + caption) + logger.info("FROM: " + org) + logger.info("TO: " + caption) # metadataを書き出して終わり - print(f"writing metadata: {args.out_json}") + logger.info(f"writing metadata: {args.out_json}") with open(args.out_json, "wt", encoding='utf-8') as f: json.dump(metadata, f, indent=2) - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: @@ -178,10 +182,10 @@ def setup_parser() -> argparse.ArgumentParser: args, unknown = parser.parse_known_args() if len(unknown) == 1: - print("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.") - print("All captions and tags in the metadata are processed.") - print("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。") - print("メタデータ内のすべてのキャプションとタグが処理されます。") + logger.warning("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.") + logger.warning("All captions and tags in the metadata are processed.") + logger.warning("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。") + logger.warning("メタデータ内のすべてのキャプションとタグが処理されます。") args.in_json = args.out_json args.out_json = unknown[0] elif len(unknown) > 0: diff --git a/finetune/make_captions.py b/finetune/make_captions.py index 074576bc2..489bdbcce 100644 --- a/finetune/make_captions.py +++ b/finetune/make_captions.py @@ -9,14 +9,22 @@ from PIL import Image from tqdm import tqdm import numpy as np + import torch +from library.device_utils import init_ipex, get_preferred_device +init_ipex() + from torchvision import transforms from torchvision.transforms.functional import InterpolationMode sys.path.append(os.path.dirname(__file__)) from blip.blip import blip_decoder, is_url import library.train_util as train_util +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) -DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") +DEVICE = get_preferred_device() IMAGE_SIZE = 384 @@ -47,7 +55,7 @@ def __getitem__(self, idx): # convert to tensor temporarily so dataloader will accept it tensor = IMAGE_TRANSFORM(image) except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") + logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") return None return (tensor, img_path) @@ -74,21 +82,21 @@ def main(args): args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path cwd = os.getcwd() - print("Current Working Directory is: ", cwd) + logger.info(f"Current Working Directory is: {cwd}") os.chdir("finetune") if not is_url(args.caption_weights) and not os.path.isfile(args.caption_weights): args.caption_weights = os.path.join("..", args.caption_weights) - print(f"load images from {args.train_data_dir}") + logger.info(f"load images from {args.train_data_dir}") train_data_dir_path = Path(args.train_data_dir) image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) - print(f"found {len(image_paths)} images.") + logger.info(f"found {len(image_paths)} images.") - print(f"loading BLIP caption: {args.caption_weights}") + logger.info(f"loading BLIP caption: {args.caption_weights}") model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit="large", med_config="./blip/med_config.json") model.eval() model = model.to(DEVICE) - print("BLIP loaded") + logger.info("BLIP loaded") # captioningする def run_batch(path_imgs): @@ -108,7 +116,7 @@ def run_batch(path_imgs): with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f: f.write(caption + "\n") if args.debug: - print(image_path, caption) + logger.info(f'{image_path} {caption}') # 読み込みの高速化のためにDataLoaderを使うオプション if args.max_data_loader_n_workers is not None: @@ -138,7 +146,7 @@ def run_batch(path_imgs): raw_image = raw_image.convert("RGB") img_tensor = IMAGE_TRANSFORM(raw_image) except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") + logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") continue b_imgs.append((image_path, img_tensor)) @@ -148,7 +156,7 @@ def run_batch(path_imgs): if len(b_imgs) > 0: run_batch(b_imgs) - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: diff --git a/finetune/make_captions_by_git.py b/finetune/make_captions_by_git.py index b3c5cc423..edeebadf3 100644 --- a/finetune/make_captions_by_git.py +++ b/finetune/make_captions_by_git.py @@ -5,12 +5,19 @@ from pathlib import Path from PIL import Image from tqdm import tqdm + import torch +from library.device_utils import init_ipex, get_preferred_device +init_ipex() + from transformers import AutoProcessor, AutoModelForCausalLM from transformers.generation.utils import GenerationMixin import library.train_util as train_util - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -35,8 +42,8 @@ def remove_words(captions, debug): for pat in PATTERN_REPLACE: cap = pat.sub("", cap) if debug and cap != caption: - print(caption) - print(cap) + logger.info(caption) + logger.info(cap) removed_caps.append(cap) return removed_caps @@ -70,16 +77,16 @@ def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs) GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch """ - print(f"load images from {args.train_data_dir}") + logger.info(f"load images from {args.train_data_dir}") train_data_dir_path = Path(args.train_data_dir) image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) - print(f"found {len(image_paths)} images.") + logger.info(f"found {len(image_paths)} images.") # できればcacheに依存せず明示的にダウンロードしたい - print(f"loading GIT: {args.model_id}") + logger.info(f"loading GIT: {args.model_id}") git_processor = AutoProcessor.from_pretrained(args.model_id) git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE) - print("GIT loaded") + logger.info("GIT loaded") # captioningする def run_batch(path_imgs): @@ -97,7 +104,7 @@ def run_batch(path_imgs): with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f: f.write(caption + "\n") if args.debug: - print(image_path, caption) + logger.info(f"{image_path} {caption}") # 読み込みの高速化のためにDataLoaderを使うオプション if args.max_data_loader_n_workers is not None: @@ -126,7 +133,7 @@ def run_batch(path_imgs): if image.mode != "RGB": image = image.convert("RGB") except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") + logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") continue b_imgs.append((image_path, image)) @@ -137,7 +144,7 @@ def run_batch(path_imgs): if len(b_imgs) > 0: run_batch(b_imgs) - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: diff --git a/finetune/merge_captions_to_metadata.py b/finetune/merge_captions_to_metadata.py index 241f6f902..89f717473 100644 --- a/finetune/merge_captions_to_metadata.py +++ b/finetune/merge_captions_to_metadata.py @@ -5,72 +5,96 @@ from tqdm import tqdm import library.train_util as train_util import os +from library.utils import setup_logging -def main(args): - assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください" - - train_data_dir_path = Path(args.train_data_dir) - image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) - print(f"found {len(image_paths)} images.") - - if args.in_json is None and Path(args.out_json).is_file(): - args.in_json = args.out_json - - if args.in_json is not None: - print(f"loading existing metadata: {args.in_json}") - metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8')) - print("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます") - else: - print("new metadata will be created / 新しいメタデータファイルが作成されます") - metadata = {} +setup_logging() +import logging - print("merge caption texts to metadata json.") - for image_path in tqdm(image_paths): - caption_path = image_path.with_suffix(args.caption_extension) - caption = caption_path.read_text(encoding='utf-8').strip() +logger = logging.getLogger(__name__) - if not os.path.exists(caption_path): - caption_path = os.path.join(image_path, args.caption_extension) - image_key = str(image_path) if args.full_path else image_path.stem - if image_key not in metadata: - metadata[image_key] = {} +def main(args): + assert not args.recursive or ( + args.recursive and args.full_path + ), "recursive requires full_path / recursiveはfull_pathと同時に指定してください" - metadata[image_key]['caption'] = caption - if args.debug: - print(image_key, caption) + train_data_dir_path = Path(args.train_data_dir) + image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) + logger.info(f"found {len(image_paths)} images.") - # metadataを書き出して終わり - print(f"writing metadata: {args.out_json}") - Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8') - print("done!") + if args.in_json is None and Path(args.out_json).is_file(): + args.in_json = args.out_json + if args.in_json is not None: + logger.info(f"loading existing metadata: {args.in_json}") + metadata = json.loads(Path(args.in_json).read_text(encoding="utf-8")) + logger.warning("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます") + else: + logger.info("new metadata will be created / 新しいメタデータファイルが作成されます") + metadata = {} -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") - parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") - parser.add_argument("--in_json", type=str, - help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)") - parser.add_argument("--caption_extention", type=str, default=None, - help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)") - parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子") - parser.add_argument("--full_path", action="store_true", - help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)") - parser.add_argument("--recursive", action="store_true", - help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す") - parser.add_argument("--debug", action="store_true", help="debug mode") + logger.info("merge caption texts to metadata json.") + for image_path in tqdm(image_paths): + caption_path = image_path.with_suffix(args.caption_extension) + caption = caption_path.read_text(encoding="utf-8").strip() - return parser + if not os.path.exists(caption_path): + caption_path = os.path.join(image_path, args.caption_extension) + image_key = str(image_path) if args.full_path else image_path.stem + if image_key not in metadata: + metadata[image_key] = {} -if __name__ == '__main__': - parser = setup_parser() + metadata[image_key]["caption"] = caption + if args.debug: + logger.info(f"{image_key} {caption}") - args = parser.parse_args() + # metadataを書き出して終わり + logger.info(f"writing metadata: {args.out_json}") + Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding="utf-8") + logger.info("done!") - # スペルミスしていたオプションを復元する - if args.caption_extention is not None: - args.caption_extension = args.caption_extention - main(args) +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") + parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") + parser.add_argument( + "--in_json", + type=str, + help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)", + ) + parser.add_argument( + "--caption_extention", + type=str, + default=None, + help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)", + ) + parser.add_argument( + "--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子" + ) + parser.add_argument( + "--full_path", + action="store_true", + help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)", + ) + parser.add_argument( + "--recursive", + action="store_true", + help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す", + ) + parser.add_argument("--debug", action="store_true", help="debug mode") + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + + # スペルミスしていたオプションを復元する + if args.caption_extention is not None: + args.caption_extension = args.caption_extention + + main(args) diff --git a/finetune/merge_dd_tags_to_metadata.py b/finetune/merge_dd_tags_to_metadata.py index db1bff6da..ce22d990e 100644 --- a/finetune/merge_dd_tags_to_metadata.py +++ b/finetune/merge_dd_tags_to_metadata.py @@ -5,67 +5,89 @@ from tqdm import tqdm import library.train_util as train_util import os +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + def main(args): - assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください" + assert not args.recursive or ( + args.recursive and args.full_path + ), "recursive requires full_path / recursiveはfull_pathと同時に指定してください" - train_data_dir_path = Path(args.train_data_dir) - image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) - print(f"found {len(image_paths)} images.") + train_data_dir_path = Path(args.train_data_dir) + image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) + logger.info(f"found {len(image_paths)} images.") - if args.in_json is None and Path(args.out_json).is_file(): - args.in_json = args.out_json + if args.in_json is None and Path(args.out_json).is_file(): + args.in_json = args.out_json - if args.in_json is not None: - print(f"loading existing metadata: {args.in_json}") - metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8')) - print("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます") - else: - print("new metadata will be created / 新しいメタデータファイルが作成されます") - metadata = {} + if args.in_json is not None: + logger.info(f"loading existing metadata: {args.in_json}") + metadata = json.loads(Path(args.in_json).read_text(encoding="utf-8")) + logger.warning("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます") + else: + logger.info("new metadata will be created / 新しいメタデータファイルが作成されます") + metadata = {} - print("merge tags to metadata json.") - for image_path in tqdm(image_paths): - tags_path = image_path.with_suffix(args.caption_extension) - tags = tags_path.read_text(encoding='utf-8').strip() + logger.info("merge tags to metadata json.") + for image_path in tqdm(image_paths): + tags_path = image_path.with_suffix(args.caption_extension) + tags = tags_path.read_text(encoding="utf-8").strip() - if not os.path.exists(tags_path): - tags_path = os.path.join(image_path, args.caption_extension) + if not os.path.exists(tags_path): + tags_path = os.path.join(image_path, args.caption_extension) - image_key = str(image_path) if args.full_path else image_path.stem - if image_key not in metadata: - metadata[image_key] = {} + image_key = str(image_path) if args.full_path else image_path.stem + if image_key not in metadata: + metadata[image_key] = {} - metadata[image_key]['tags'] = tags - if args.debug: - print(image_key, tags) + metadata[image_key]["tags"] = tags + if args.debug: + logger.info(f"{image_key} {tags}") - # metadataを書き出して終わり - print(f"writing metadata: {args.out_json}") - Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8') + # metadataを書き出して終わり + logger.info(f"writing metadata: {args.out_json}") + Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding="utf-8") - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") - parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") - parser.add_argument("--in_json", type=str, - help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)") - parser.add_argument("--full_path", action="store_true", - help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)") - parser.add_argument("--recursive", action="store_true", - help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す") - parser.add_argument("--caption_extension", type=str, default=".txt", - help="extension of caption (tag) file / 読み込むキャプション(タグ)ファイルの拡張子") - parser.add_argument("--debug", action="store_true", help="debug mode, print tags") - - return parser - - -if __name__ == '__main__': - parser = setup_parser() - - args = parser.parse_args() - main(args) + parser = argparse.ArgumentParser() + parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") + parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") + parser.add_argument( + "--in_json", + type=str, + help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)", + ) + parser.add_argument( + "--full_path", + action="store_true", + help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)", + ) + parser.add_argument( + "--recursive", + action="store_true", + help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す", + ) + parser.add_argument( + "--caption_extension", + type=str, + default=".txt", + help="extension of caption (tag) file / 読み込むキャプション(タグ)ファイルの拡張子", + ) + parser.add_argument("--debug", action="store_true", help="debug mode, print tags") + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + main(args) diff --git a/finetune/prepare_buckets_latents.py b/finetune/prepare_buckets_latents.py index 1bccb1d3b..0389da388 100644 --- a/finetune/prepare_buckets_latents.py +++ b/finetune/prepare_buckets_latents.py @@ -8,13 +8,21 @@ import numpy as np from PIL import Image import cv2 + import torch +from library.device_utils import init_ipex, get_preferred_device +init_ipex() + from torchvision import transforms import library.model_util as model_util import library.train_util as train_util +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) -DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") +DEVICE = get_preferred_device() IMAGE_TRANSFORMS = transforms.Compose( [ @@ -51,22 +59,22 @@ def get_npz_filename(data_dir, image_key, is_full_path, recursive): def main(args): # assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります" if args.bucket_reso_steps % 8 > 0: - print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります") + logger.warning(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります") if args.bucket_reso_steps % 32 > 0: - print( + logger.warning( f"WARNING: bucket_reso_steps is not divisible by 32. It is not working with SDXL / bucket_reso_stepsが32で割り切れません。SDXLでは動作しません" ) train_data_dir_path = Path(args.train_data_dir) image_paths: List[str] = [str(p) for p in train_util.glob_images_pathlib(train_data_dir_path, args.recursive)] - print(f"found {len(image_paths)} images.") + logger.info(f"found {len(image_paths)} images.") if os.path.exists(args.in_json): - print(f"loading existing metadata: {args.in_json}") + logger.info(f"loading existing metadata: {args.in_json}") with open(args.in_json, "rt", encoding="utf-8") as f: metadata = json.load(f) else: - print(f"no metadata / メタデータファイルがありません: {args.in_json}") + logger.error(f"no metadata / メタデータファイルがありません: {args.in_json}") return weight_dtype = torch.float32 @@ -89,7 +97,7 @@ def main(args): if not args.bucket_no_upscale: bucket_manager.make_buckets() else: - print( + logger.warning( "min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます" ) @@ -130,7 +138,7 @@ def process_batch(is_last): if image.mode != "RGB": image = image.convert("RGB") except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") + logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") continue image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0] @@ -183,15 +191,15 @@ def process_batch(is_last): for i, reso in enumerate(bucket_manager.resos): count = bucket_counts.get(reso, 0) if count > 0: - print(f"bucket {i} {reso}: {count}") + logger.info(f"bucket {i} {reso}: {count}") img_ar_errors = np.array(img_ar_errors) - print(f"mean ar error: {np.mean(img_ar_errors)}") + logger.info(f"mean ar error: {np.mean(img_ar_errors)}") # metadataを書き出して終わり - print(f"writing metadata: {args.out_json}") + logger.info(f"writing metadata: {args.out_json}") with open(args.out_json, "wt", encoding="utf-8") as f: json.dump(metadata, f, indent=2) - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index fbf328e83..401c6d1ec 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -11,6 +11,12 @@ from tqdm import tqdm import library.train_util as train_util +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) # from wd14 tagger IMAGE_SIZE = 448 @@ -58,7 +64,7 @@ def __getitem__(self, idx): image = preprocess_image(image) tensor = torch.tensor(image) except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") + logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") return None return (tensor, img_path) @@ -75,36 +81,44 @@ def collate_fn_remove_corrupted(batch): def main(args): + # model location is model_dir + repo_id + # repo id may be like "user/repo" or "user/repo/branch", so we need to remove slash + model_location = os.path.join(args.model_dir, args.repo_id.replace("/", "_")) + # hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする # depreacatedの警告が出るけどなくなったらその時 # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22 - if not os.path.exists(args.model_dir) or args.force_download: - print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}") + if not os.path.exists(model_location) or args.force_download: + os.makedirs(args.model_dir, exist_ok=True) + logger.info(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}") files = FILES if args.onnx: + files = ["selected_tags.csv"] files += FILES_ONNX + else: + for file in SUB_DIR_FILES: + hf_hub_download( + args.repo_id, + file, + subfolder=SUB_DIR, + cache_dir=os.path.join(model_location, SUB_DIR), + force_download=True, + force_filename=file, + ) for file in files: - hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file) - for file in SUB_DIR_FILES: - hf_hub_download( - args.repo_id, - file, - subfolder=SUB_DIR, - cache_dir=os.path.join(args.model_dir, SUB_DIR), - force_download=True, - force_filename=file, - ) + hf_hub_download(args.repo_id, file, cache_dir=model_location, force_download=True, force_filename=file) else: - print("using existing wd14 tagger model") + logger.info("using existing wd14 tagger model") # 画像を読み込む if args.onnx: + import torch import onnx import onnxruntime as ort - onnx_path = f"{args.model_dir}/model.onnx" - print("Running wd14 tagger with onnx") - print(f"loading onnx model: {onnx_path}") + onnx_path = f"{model_location}/model.onnx" + logger.info("Running wd14 tagger with onnx") + logger.info(f"loading onnx model: {onnx_path}") if not os.path.exists(onnx_path): raise Exception( @@ -119,9 +133,9 @@ def main(args): except: batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_param - if args.batch_size != batch_size and type(batch_size) != str: + if args.batch_size != batch_size and type(batch_size) != str and batch_size > 0: # some rebatch model may use 'N' as dynamic axes - print( + logger.warning( f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}" ) args.batch_size = batch_size @@ -130,19 +144,19 @@ def main(args): ort_sess = ort.InferenceSession( onnx_path, - providers=["CUDAExecutionProvider"] - if "CUDAExecutionProvider" in ort.get_available_providers() - else ["CPUExecutionProvider"], + providers=( + ["CUDAExecutionProvider"] if "CUDAExecutionProvider" in ort.get_available_providers() else ["CPUExecutionProvider"] + ), ) else: from tensorflow.keras.models import load_model - model = load_model(f"{args.model_dir}") + model = load_model(f"{model_location}") # label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv") # 依存ライブラリを増やしたくないので自力で読むよ - with open(os.path.join(args.model_dir, CSV_FILE), "r", encoding="utf-8") as f: + with open(os.path.join(model_location, CSV_FILE), "r", encoding="utf-8") as f: reader = csv.reader(f) l = [row for row in reader] header = l[0] # tag_id,name,category,count @@ -156,7 +170,7 @@ def main(args): train_data_dir_path = Path(args.train_data_dir) image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) - print(f"found {len(image_paths)} images.") + logger.info(f"found {len(image_paths)} images.") tag_freq = {} @@ -168,8 +182,8 @@ def run_batch(path_imgs): imgs = np.array([im for _, im in path_imgs]) if args.onnx: - if len(imgs) < args.batch_size: - imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0) + # if len(imgs) < args.batch_size: + # imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0) probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy probs = probs[: len(path_imgs)] else: @@ -237,7 +251,10 @@ def run_batch(path_imgs): with open(caption_file, "wt", encoding="utf-8") as f: f.write(tag_text + "\n") if args.debug: - print(f"\n{image_path}:\n Character tags: {character_tag_text}\n General tags: {general_tag_text}") + logger.info("") + logger.info(f"{image_path}:") + logger.info(f"\tCharacter tags: {character_tag_text}") + logger.info(f"\tGeneral tags: {general_tag_text}") # 読み込みの高速化のためにDataLoaderを使うオプション if args.max_data_loader_n_workers is not None: @@ -269,7 +286,7 @@ def run_batch(path_imgs): image = image.convert("RGB") image = preprocess_image(image) except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") + logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") continue b_imgs.append((image_path, image)) @@ -284,11 +301,11 @@ def run_batch(path_imgs): if args.frequency_tags: sorted_tags = sorted(tag_freq.items(), key=lambda x: x[1], reverse=True) - print("\nTag frequencies:") + print("Tag frequencies:") for tag, freq in sorted_tags: print(f"{tag}: {freq}") - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: @@ -307,7 +324,9 @@ def setup_parser() -> argparse.ArgumentParser: help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ", ) parser.add_argument( - "--force_download", action="store_true", help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします" + "--force_download", + action="store_true", + help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします", ) parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") parser.add_argument( @@ -322,8 +341,12 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)", ) - parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子") - parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値") + parser.add_argument( + "--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子" + ) + parser.add_argument( + "--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値" + ) parser.add_argument( "--general_threshold", type=float, @@ -336,7 +359,9 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="threshold of confidence to add a tag for character category, same as --thres if omitted / characterカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ", ) - parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する") + parser.add_argument( + "--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する" + ) parser.add_argument( "--remove_underscore", action="store_true", @@ -349,9 +374,13 @@ def setup_parser() -> argparse.ArgumentParser: default="", help="comma-separated list of undesired tags to remove from the output / 出力から除外したいタグのカンマ区切りのリスト", ) - parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する") + parser.add_argument( + "--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する" + ) parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する") - parser.add_argument("--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する") + parser.add_argument( + "--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する" + ) parser.add_argument( "--caption_separator", type=str, diff --git a/gen_img.py b/gen_img.py new file mode 100644 index 000000000..4fe898716 --- /dev/null +++ b/gen_img.py @@ -0,0 +1,3334 @@ +import itertools +import json +from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable +import glob +import importlib +import importlib.util +import sys +import inspect +import time +import zipfile +from diffusers.utils import deprecate +from diffusers.configuration_utils import FrozenDict +import argparse +import math +import os +import random +import re + +import diffusers +import numpy as np +import torch + +from library.device_utils import init_ipex, clean_memory, get_preferred_device + +init_ipex() + +import torchvision +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + DPMSolverSinglestepScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + DDIMScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + KDPM2DiscreteScheduler, + KDPM2AncestralDiscreteScheduler, + # UNet2DConditionModel, + StableDiffusionPipeline, +) +from einops import rearrange +from tqdm import tqdm +from torchvision import transforms +from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPImageProcessor +import PIL +from PIL import Image +from PIL.PngImagePlugin import PngInfo + +import library.model_util as model_util +import library.train_util as train_util +import library.sdxl_model_util as sdxl_model_util +import library.sdxl_train_util as sdxl_train_util +from networks.lora import LoRANetwork +import tools.original_control_net as original_control_net +from tools.original_control_net import ControlNetInfo +from library.original_unet import UNet2DConditionModel, InferUNet2DConditionModel +from library.sdxl_original_unet import InferSdxlUNet2DConditionModel +from library.original_unet import FlashAttentionFunction +from networks.control_net_lllite import ControlNetLLLite +from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +# scheduler: +SCHEDULER_LINEAR_START = 0.00085 +SCHEDULER_LINEAR_END = 0.0120 +SCHEDULER_TIMESTEPS = 1000 +SCHEDLER_SCHEDULE = "scaled_linear" + +# その他の設定 +LATENT_CHANNELS = 4 +DOWNSAMPLING_FACTOR = 8 + +CLIP_VISION_MODEL = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + +# region モジュール入れ替え部 +""" +高速化のためのモジュール入れ替え +""" + + +def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa): + if mem_eff_attn: + logger.info("Enable memory efficient attention for U-Net") + + # これはDiffusersのU-Netではなく自前のU-Netなので置き換えなくても良い + unet.set_use_memory_efficient_attention(False, True) + elif xformers: + logger.info("Enable xformers for U-Net") + try: + import xformers.ops + except ImportError: + raise ImportError("No xformers / xformersがインストールされていないようです") + + unet.set_use_memory_efficient_attention(True, False) + elif sdpa: + logger.info("Enable SDPA for U-Net") + unet.set_use_memory_efficient_attention(False, False) + unet.set_use_sdpa(True) + + +# TODO common train_util.py +def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers, sdpa): + if mem_eff_attn: + replace_vae_attn_to_memory_efficient() + elif xformers: + # replace_vae_attn_to_xformers() # 解像度によってxformersがエラーを出す? + vae.set_use_memory_efficient_attention_xformers(True) # とりあえずこっちを使う + elif sdpa: + replace_vae_attn_to_sdpa() + + +def replace_vae_attn_to_memory_efficient(): + logger.info("VAE Attention.forward has been replaced to FlashAttention (not xformers)") + flash_func = FlashAttentionFunction + + def forward_flash_attn(self, hidden_states, **kwargs): + q_bucket_size = 512 + k_bucket_size = 1024 + + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = self.to_q(hidden_states) + key_proj = self.to_k(hidden_states) + value_proj = self.to_v(hidden_states) + + query_proj, key_proj, value_proj = map( + lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (query_proj, key_proj, value_proj) + ) + + out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size) + + out = rearrange(out, "b h n d -> b n (h d)") + + # compute next hidden_states + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + + def forward_flash_attn_0_14(self, hidden_states, **kwargs): + if not hasattr(self, "to_q"): + self.to_q = self.query + self.to_k = self.key + self.to_v = self.value + self.to_out = [self.proj_attn, torch.nn.Identity()] + self.heads = self.num_heads + return forward_flash_attn(self, hidden_states, **kwargs) + + if diffusers.__version__ < "0.15.0": + diffusers.models.attention.AttentionBlock.forward = forward_flash_attn_0_14 + else: + diffusers.models.attention_processor.Attention.forward = forward_flash_attn + + +def replace_vae_attn_to_xformers(): + logger.info("VAE: Attention.forward has been replaced to xformers") + import xformers.ops + + def forward_xformers(self, hidden_states, **kwargs): + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = self.to_q(hidden_states) + key_proj = self.to_k(hidden_states) + value_proj = self.to_v(hidden_states) + + query_proj, key_proj, value_proj = map( + lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (query_proj, key_proj, value_proj) + ) + + query_proj = query_proj.contiguous() + key_proj = key_proj.contiguous() + value_proj = value_proj.contiguous() + out = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None) + + out = rearrange(out, "b h n d -> b n (h d)") + + # compute next hidden_states + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + + def forward_xformers_0_14(self, hidden_states, **kwargs): + if not hasattr(self, "to_q"): + self.to_q = self.query + self.to_k = self.key + self.to_v = self.value + self.to_out = [self.proj_attn, torch.nn.Identity()] + self.heads = self.num_heads + return forward_xformers(self, hidden_states, **kwargs) + + if diffusers.__version__ < "0.15.0": + diffusers.models.attention.AttentionBlock.forward = forward_xformers_0_14 + else: + diffusers.models.attention_processor.Attention.forward = forward_xformers + + +def replace_vae_attn_to_sdpa(): + logger.info("VAE: Attention.forward has been replaced to sdpa") + + def forward_sdpa(self, hidden_states, **kwargs): + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = self.to_q(hidden_states) + key_proj = self.to_k(hidden_states) + value_proj = self.to_v(hidden_states) + + query_proj, key_proj, value_proj = map( + lambda t: rearrange(t, "b n (h d) -> b n h d", h=self.heads), (query_proj, key_proj, value_proj) + ) + + out = torch.nn.functional.scaled_dot_product_attention( + query_proj, key_proj, value_proj, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + out = rearrange(out, "b n h d -> b n (h d)") + + # compute next hidden_states + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + + def forward_sdpa_0_14(self, hidden_states, **kwargs): + if not hasattr(self, "to_q"): + self.to_q = self.query + self.to_k = self.key + self.to_v = self.value + self.to_out = [self.proj_attn, torch.nn.Identity()] + self.heads = self.num_heads + return forward_sdpa(self, hidden_states, **kwargs) + + if diffusers.__version__ < "0.15.0": + diffusers.models.attention.AttentionBlock.forward = forward_sdpa_0_14 + else: + diffusers.models.attention_processor.Attention.forward = forward_sdpa + + +# endregion + +# region 画像生成の本体:lpw_stable_diffusion.py (ASL)からコピーして修正 +# https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py +# Pipelineだけ独立して使えないのと機能追加するのとでコピーして修正 + + +class PipelineLike: + def __init__( + self, + is_sdxl, + device, + vae: AutoencoderKL, + text_encoders: List[CLIPTextModel], + tokenizers: List[CLIPTokenizer], + unet: InferSdxlUNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + clip_skip: int, + ): + super().__init__() + self.is_sdxl = is_sdxl + self.device = device + self.clip_skip = clip_skip + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + self.vae = vae + self.text_encoders = text_encoders + self.tokenizers = tokenizers + self.unet: Union[InferUNet2DConditionModel, InferSdxlUNet2DConditionModel] = unet + self.scheduler = scheduler + self.safety_checker = None + + self.clip_vision_model: CLIPVisionModelWithProjection = None + self.clip_vision_processor: CLIPImageProcessor = None + self.clip_vision_strength = 0.0 + + # Textual Inversion + self.token_replacements_list = [] + for _ in range(len(self.text_encoders)): + self.token_replacements_list.append({}) + + # ControlNet + self.control_nets: List[ControlNetInfo] = [] # only for SD 1.5 + self.control_net_lllites: List[ControlNetLLLite] = [] + self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない + + self.gradual_latent: GradualLatent = None + + # Textual Inversion + def add_token_replacement(self, text_encoder_index, target_token_id, rep_token_ids): + self.token_replacements_list[text_encoder_index][target_token_id] = rep_token_ids + + def set_enable_control_net(self, en: bool): + self.control_net_enabled = en + + def get_token_replacer(self, tokenizer): + tokenizer_index = self.tokenizers.index(tokenizer) + token_replacements = self.token_replacements_list[tokenizer_index] + + def replace_tokens(tokens): + # print("replace_tokens", tokens, "=>", token_replacements) + if isinstance(tokens, torch.Tensor): + tokens = tokens.tolist() + + new_tokens = [] + for token in tokens: + if token in token_replacements: + replacement = token_replacements[token] + new_tokens.extend(replacement) + else: + new_tokens.append(token) + return new_tokens + + return replace_tokens + + def set_control_nets(self, ctrl_nets): + self.control_nets = ctrl_nets + + def set_control_net_lllites(self, ctrl_net_lllites): + self.control_net_lllites = ctrl_net_lllites + + def set_gradual_latent(self, gradual_latent): + if gradual_latent is None: + logger.info("gradual_latent is disabled") + self.gradual_latent = None + else: + logger.info(f"gradual_latent is enabled: {gradual_latent}") + self.gradual_latent = gradual_latent # (ds_ratio, start_timesteps, every_n_steps, ratio_step) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + init_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, + mask_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, + height: int = 1024, + width: int = 1024, + original_height: int = None, + original_width: int = None, + original_height_negative: int = None, + original_width_negative: int = None, + crop_top: int = 0, + crop_left: int = 0, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_scale: float = None, + strength: float = 0.8, + # num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + vae_batch_size: float = None, + return_latents: bool = False, + # return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + is_cancelled_callback: Optional[Callable[[], bool]] = None, + callback_steps: Optional[int] = 1, + img2img_noise=None, + clip_guide_images=None, + emb_normalize_mode: str = "original", + **kwargs, + ): + # TODO support secondary prompt + num_images_per_prompt = 1 # fixed because already prompt is repeated + + if isinstance(prompt, str): + batch_size = 1 + prompt = [prompt] + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + regional_network = " AND " in prompt[0] + + vae_batch_size = ( + batch_size + if vae_batch_size is None + else (int(vae_batch_size) if vae_batch_size >= 1 else max(1, int(batch_size * vae_batch_size))) + ) + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." + ) + + # get prompt text embeddings + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + if not do_classifier_free_guidance and negative_scale is not None: + logger.warning(f"negative_scale is ignored if guidance scalle <= 1.0") + negative_scale = None + + # get unconditional embeddings for classifier free guidance + if negative_prompt is None: + negative_prompt = [""] * batch_size + elif isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * batch_size + if batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + tes_text_embs = [] + tes_uncond_embs = [] + tes_real_uncond_embs = [] + + for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders): + token_replacer = self.get_token_replacer(tokenizer) + + # use last text_pool, because it is from text encoder 2 + text_embeddings, text_pool, uncond_embeddings, uncond_pool, _ = get_weighted_text_embeddings( + self.is_sdxl, + tokenizer, + text_encoder, + prompt=prompt, + uncond_prompt=negative_prompt if do_classifier_free_guidance else None, + max_embeddings_multiples=max_embeddings_multiples, + clip_skip=self.clip_skip, + token_replacer=token_replacer, + device=self.device, + emb_normalize_mode=emb_normalize_mode, + **kwargs, + ) + tes_text_embs.append(text_embeddings) + tes_uncond_embs.append(uncond_embeddings) + + if negative_scale is not None: + _, real_uncond_embeddings, _ = get_weighted_text_embeddings( + self.is_sdxl, + token_replacer, + prompt=prompt, # こちらのトークン長に合わせてuncondを作るので75トークン超で必須 + uncond_prompt=[""] * batch_size, + max_embeddings_multiples=max_embeddings_multiples, + clip_skip=self.clip_skip, + token_replacer=token_replacer, + device=self.device, + emb_normalize_mode=emb_normalize_mode, + **kwargs, + ) + tes_real_uncond_embs.append(real_uncond_embeddings) + + # concat text encoder outputs + text_embeddings = tes_text_embs[0] + uncond_embeddings = tes_uncond_embs[0] + for i in range(1, len(tes_text_embs)): + text_embeddings = torch.cat([text_embeddings, tes_text_embs[i]], dim=2) # n,77,2048 + if do_classifier_free_guidance: + uncond_embeddings = torch.cat([uncond_embeddings, tes_uncond_embs[i]], dim=2) # n,77,2048 + + if do_classifier_free_guidance: + if negative_scale is None: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + else: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]) + + if self.control_net_lllites: + # ControlNetのhintにguide imageを流用する。ControlNetの場合はControlNet側で行う + if isinstance(clip_guide_images, PIL.Image.Image): + clip_guide_images = [clip_guide_images] + if isinstance(clip_guide_images[0], PIL.Image.Image): + clip_guide_images = [preprocess_image(im) for im in clip_guide_images] + clip_guide_images = torch.cat(clip_guide_images) + if isinstance(clip_guide_images, list): + clip_guide_images = torch.stack(clip_guide_images) + + clip_guide_images = clip_guide_images.to(self.device, dtype=text_embeddings.dtype) + + # create size embs + if original_height is None: + original_height = height + if original_width is None: + original_width = width + if original_height_negative is None: + original_height_negative = original_height + if original_width_negative is None: + original_width_negative = original_width + if crop_top is None: + crop_top = 0 + if crop_left is None: + crop_left = 0 + if self.is_sdxl: + emb1 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256) + uc_emb1 = sdxl_train_util.get_timestep_embedding( + torch.FloatTensor([original_height_negative, original_width_negative]).unsqueeze(0), 256 + ) + emb2 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256) + emb3 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([height, width]).unsqueeze(0), 256) + c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1) + uc_vector = torch.cat([uc_emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1) + + if regional_network: + # use last pool for conditioning + num_sub_prompts = len(text_pool) // batch_size + text_pool = text_pool[num_sub_prompts - 1 :: num_sub_prompts] # last subprompt + + if init_image is not None and self.clip_vision_model is not None: + logger.info(f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}") + vision_input = self.clip_vision_processor(init_image, return_tensors="pt", device=self.device) + pixel_values = vision_input["pixel_values"].to(self.device, dtype=text_embeddings.dtype) + + clip_vision_embeddings = self.clip_vision_model( + pixel_values=pixel_values, output_hidden_states=True, return_dict=True + ) + clip_vision_embeddings = clip_vision_embeddings.image_embeds + + if len(clip_vision_embeddings) == 1 and batch_size > 1: + clip_vision_embeddings = clip_vision_embeddings.repeat((batch_size, 1)) + + clip_vision_embeddings = clip_vision_embeddings * self.clip_vision_strength + assert clip_vision_embeddings.shape == text_pool.shape, f"{clip_vision_embeddings.shape} != {text_pool.shape}" + text_pool = clip_vision_embeddings # replace: same as ComfyUI (?) + + c_vector = torch.cat([text_pool, c_vector], dim=1) + if do_classifier_free_guidance: + uc_vector = torch.cat([uncond_pool, uc_vector], dim=1) + vector_embeddings = torch.cat([uc_vector, c_vector]) + else: + vector_embeddings = c_vector + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps, self.device) + + latents_dtype = text_embeddings.dtype + init_latents_orig = None + mask = None + + if init_image is None: + # get the initial random noise unless the user supplied it + + # Unlike in other pipelines, latents need to be generated in the target device + # for 1-to-1 results reproducibility with the CompVis implementation. + # However this currently doesn't work in `mps`. + latents_shape = ( + batch_size * num_images_per_prompt, + self.unet.in_channels, + height // 8, + width // 8, + ) + + if latents is None: + if self.device.type == "mps": + # randn does not exist on mps + latents = torch.randn( + latents_shape, + generator=generator, + device="cpu", + dtype=latents_dtype, + ).to(self.device) + else: + latents = torch.randn( + latents_shape, + generator=generator, + device=self.device, + dtype=latents_dtype, + ) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + latents = latents.to(self.device) + + timesteps = self.scheduler.timesteps.to(self.device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + else: + # image to tensor + if isinstance(init_image, PIL.Image.Image): + init_image = [init_image] + if isinstance(init_image[0], PIL.Image.Image): + init_image = [preprocess_image(im) for im in init_image] + init_image = torch.cat(init_image) + if isinstance(init_image, list): + init_image = torch.stack(init_image) + + # mask image to tensor + if mask_image is not None: + if isinstance(mask_image, PIL.Image.Image): + mask_image = [mask_image] + if isinstance(mask_image[0], PIL.Image.Image): + mask_image = torch.cat([preprocess_mask(im) for im in mask_image]) # H*W, 0 for repaint + + # encode the init image into latents and scale the latents + init_image = init_image.to(device=self.device, dtype=latents_dtype) + if init_image.size()[-2:] == (height // 8, width // 8): + init_latents = init_image + else: + if vae_batch_size >= batch_size: + init_latent_dist = self.vae.encode(init_image.to(self.vae.dtype)).latent_dist + init_latents = init_latent_dist.sample(generator=generator) + else: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + init_latents = [] + for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)): + init_latent_dist = self.vae.encode( + (init_image[i : i + vae_batch_size] if vae_batch_size > 1 else init_image[i].unsqueeze(0)).to( + self.vae.dtype + ) + ).latent_dist + init_latents.append(init_latent_dist.sample(generator=generator)) + init_latents = torch.cat(init_latents) + + init_latents = (sdxl_model_util.VAE_SCALE_FACTOR if self.is_sdxl else 0.18215) * init_latents + + if len(init_latents) == 1: + init_latents = init_latents.repeat((batch_size, 1, 1, 1)) + init_latents_orig = init_latents + + # preprocess mask + if mask_image is not None: + mask = mask_image.to(device=self.device, dtype=latents_dtype) + if len(mask) == 1: + mask = mask.repeat((batch_size, 1, 1, 1)) + + # check sizes + if not mask.shape == init_latents.shape: + raise ValueError("The mask and init_image should be the same size!") + + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + timesteps = self.scheduler.timesteps[-init_timestep] + timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device) + + # add noise to latents using the timesteps + latents = self.scheduler.add_noise(init_latents, img2img_noise, timesteps) + + t_start = max(num_inference_steps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:].to(self.device) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1 + + if self.control_nets: + guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images) + each_control_net_enabled = [self.control_net_enabled] * len(self.control_nets) + + if self.control_net_lllites: + # guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images) + if self.control_net_enabled: + for control_net, _ in self.control_net_lllites: + with torch.no_grad(): + control_net.set_cond_image(clip_guide_images) + else: + for control_net, _ in self.control_net_lllites: + control_net.set_cond_image(None) + + each_control_net_enabled = [self.control_net_enabled] * len(self.control_net_lllites) + + enable_gradual_latent = False + if self.gradual_latent: + if not hasattr(self.scheduler, "set_gradual_latent_params"): + logger.warning("gradual_latent is not supported for this scheduler. Ignoring.") + logger.warning(f"{self.scheduler.__class__.__name__}") + else: + enable_gradual_latent = True + step_elapsed = 1000 + current_ratio = self.gradual_latent.ratio + + # first, we downscale the latents to the specified ratio / 最初に指定された比率にlatentsをダウンスケールする + height, width = latents.shape[-2:] + org_dtype = latents.dtype + if org_dtype == torch.bfloat16: + latents = latents.float() + latents = torch.nn.functional.interpolate( + latents, scale_factor=current_ratio, mode="bicubic", align_corners=False + ).to(org_dtype) + + # apply unsharp mask / アンシャープマスクを適用する + if self.gradual_latent.gaussian_blur_ksize: + latents = self.gradual_latent.apply_unshark_mask(latents) + + for i, t in enumerate(tqdm(timesteps)): + resized_size = None + if enable_gradual_latent: + # gradually upscale the latents / latentsを徐々にアップスケールする + if ( + t < self.gradual_latent.start_timesteps + and current_ratio < 1.0 + and step_elapsed >= self.gradual_latent.every_n_steps + ): + current_ratio = min(current_ratio + self.gradual_latent.ratio_step, 1.0) + # make divisible by 8 because size of latents must be divisible at bottom of UNet + h = int(height * current_ratio) // 8 * 8 + w = int(width * current_ratio) // 8 * 8 + resized_size = (h, w) + self.scheduler.set_gradual_latent_params(resized_size, self.gradual_latent) + step_elapsed = 0 + else: + self.scheduler.set_gradual_latent_params(None, None) + step_elapsed += 1 + + # expand the latents if we are doing classifier free guidance + latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # disable ControlNet-LLLite if ratio is set. ControlNet is disabled in ControlNetInfo + if self.control_net_lllites: + for j, ((control_net, ratio), enabled) in enumerate(zip(self.control_net_lllites, each_control_net_enabled)): + if not enabled or ratio >= 1.0: + continue + if ratio < i / len(timesteps): + logger.info(f"ControlNetLLLite {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})") + control_net.set_cond_image(None) + each_control_net_enabled[j] = False + + # predict the noise residual + if self.control_nets and self.control_net_enabled: + if regional_network: + num_sub_and_neg_prompts = len(text_embeddings) // batch_size + text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt + else: + text_emb_last = text_embeddings + + noise_pred = original_control_net.call_unet_and_control_net( + i, + num_latent_input, + self.unet, + self.control_nets, + guided_hints, + i / len(timesteps), + latent_model_input, + t, + text_embeddings, + text_emb_last, + ).sample + elif self.is_sdxl: + noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings) + else: + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + if negative_scale is None: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input) # uncond by negative prompt + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + else: + noise_pred_negative, noise_pred_text, noise_pred_uncond = noise_pred.chunk( + num_latent_input + ) # uncond is real uncond + noise_pred = ( + noise_pred_uncond + + guidance_scale * (noise_pred_text - noise_pred_uncond) + - negative_scale * (noise_pred_negative - noise_pred_uncond) + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if mask is not None: + # masking + init_latents_proper = self.scheduler.add_noise(init_latents_orig, img2img_noise, torch.tensor([t])) + latents = (init_latents_proper * mask) + (latents * (1 - mask)) + + # call the callback, if provided + if i % callback_steps == 0: + if callback is not None: + callback(i, t, latents) + if is_cancelled_callback is not None and is_cancelled_callback(): + return None + + if return_latents: + return latents + + latents = 1 / (sdxl_model_util.VAE_SCALE_FACTOR if self.is_sdxl else 0.18215) * latents + if vae_batch_size >= batch_size: + image = self.vae.decode(latents.to(self.vae.dtype)).sample + else: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + images = [] + for i in tqdm(range(0, batch_size, vae_batch_size)): + images.append( + self.vae.decode( + (latents[i : i + vae_batch_size] if vae_batch_size > 1 else latents[i].unsqueeze(0)).to(self.vae.dtype) + ).sample + ) + image = torch.cat(images) + + image = (image / 2 + 0.5).clamp(0, 1) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + if output_type == "pil": + # image = self.numpy_to_pil(image) + image = (image * 255).round().astype("uint8") + image = [Image.fromarray(im) for im in image] + + return image + + # return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + +re_attention = re.compile( + r""" +\\\(| +\\\)| +\\\[| +\\]| +\\\\| +\\| +\(| +\[| +:([+-]?[.\d]+)\)| +\)| +]| +[^\\()\[\]:]+| +: +""", + re.X, +) + + +def parse_prompt_attention(text): + """ + Parses a string with attention tokens and returns a list of pairs: text and its associated weight. + Accepted tokens are: + (abc) - increases attention to abc by a multiplier of 1.1 + (abc:3.12) - increases attention to abc by a multiplier of 3.12 + [abc] - decreases attention to abc by a multiplier of 1.1 + \( - literal character '(' + \[ - literal character '[' + \) - literal character ')' + \] - literal character ']' + \\ - literal character '\' + anything else - just text + >>> parse_prompt_attention('normal text') + [['normal text', 1.0]] + >>> parse_prompt_attention('an (important) word') + [['an ', 1.0], ['important', 1.1], [' word', 1.0]] + >>> parse_prompt_attention('(unbalanced') + [['unbalanced', 1.1]] + >>> parse_prompt_attention('\(literal\]') + [['(literal]', 1.0]] + >>> parse_prompt_attention('(unnecessary)(parens)') + [['unnecessaryparens', 1.1]] + >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') + [['a ', 1.0], + ['house', 1.5730000000000004], + [' ', 1.1], + ['on', 1.0], + [' a ', 1.1], + ['hill', 0.55], + [', sun, ', 1.1], + ['sky', 1.4641000000000006], + ['.', 1.1]] + """ + + res = [] + round_brackets = [] + square_brackets = [] + + round_bracket_multiplier = 1.1 + square_bracket_multiplier = 1 / 1.1 + + def multiply_range(start_position, multiplier): + for p in range(start_position, len(res)): + res[p][1] *= multiplier + + # keep break as separate token + text = text.replace("BREAK", "\\BREAK\\") + + for m in re_attention.finditer(text): + text = m.group(0) + weight = m.group(1) + + if text.startswith("\\"): + res.append([text[1:], 1.0]) + elif text == "(": + round_brackets.append(len(res)) + elif text == "[": + square_brackets.append(len(res)) + elif weight is not None and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), float(weight)) + elif text == ")" and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), round_bracket_multiplier) + elif text == "]" and len(square_brackets) > 0: + multiply_range(square_brackets.pop(), square_bracket_multiplier) + else: + res.append([text, 1.0]) + + for pos in round_brackets: + multiply_range(pos, round_bracket_multiplier) + + for pos in square_brackets: + multiply_range(pos, square_bracket_multiplier) + + if len(res) == 0: + res = [["", 1.0]] + + # merge runs of identical weights + i = 0 + while i + 1 < len(res): + if res[i][1] == res[i + 1][1] and res[i][0].strip() != "BREAK" and res[i + 1][0].strip() != "BREAK": + res[i][0] += res[i + 1][0] + res.pop(i + 1) + else: + i += 1 + + return res + + +def get_prompts_with_weights(tokenizer: CLIPTokenizer, token_replacer, prompt: List[str], max_length: int): + r""" + Tokenize a list of prompts and return its tokens with weights of each token. + No padding, starting or ending token is included. + """ + tokens = [] + weights = [] + truncated = False + + for text in prompt: + texts_and_weights = parse_prompt_attention(text) + text_token = [] + text_weight = [] + for word, weight in texts_and_weights: + if word.strip() == "BREAK": + # pad until next multiple of tokenizer's max token length + pad_len = tokenizer.model_max_length - (len(text_token) % tokenizer.model_max_length) + logger.info(f"BREAK pad_len: {pad_len}") + for i in range(pad_len): + # v2のときEOSをつけるべきかどうかわからないぜ + # if i == 0: + # text_token.append(tokenizer.eos_token_id) + # else: + text_token.append(tokenizer.pad_token_id) + text_weight.append(1.0) + continue + + # tokenize and discard the starting and the ending token + token = tokenizer(word).input_ids[1:-1] + + token = token_replacer(token) # for Textual Inversion + + text_token += token + # copy the weight by length of token + text_weight += [weight] * len(token) + # stop if the text is too long (longer than truncation limit) + if len(text_token) > max_length: + truncated = True + break + # truncate + if len(text_token) > max_length: + truncated = True + text_token = text_token[:max_length] + text_weight = text_weight[:max_length] + tokens.append(text_token) + weights.append(text_weight) + if truncated: + logger.warning("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") + return tokens, weights + + +def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77): + r""" + Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. + """ + max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) + weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length + for i in range(len(tokens)): + tokens[i] = [bos] + tokens[i] + [eos] + [pad] * (max_length - 2 - len(tokens[i])) + if no_boseos_middle: + weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) + else: + w = [] + if len(weights[i]) == 0: + w = [1.0] * weights_length + else: + for j in range(max_embeddings_multiples): + w.append(1.0) # weight for starting token in this chunk + w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))] + w.append(1.0) # weight for ending token in this chunk + w += [1.0] * (weights_length - len(w)) + weights[i] = w[:] + + return tokens, weights + + +def get_unweighted_text_embeddings( + is_sdxl: bool, + text_encoder: CLIPTextModel, + text_input: torch.Tensor, + chunk_length: int, + clip_skip: int, + eos: int, + pad: int, + no_boseos_middle: Optional[bool] = True, +): + """ + When the length of tokens is a multiple of the capacity of the text encoder, + it should be split into chunks and sent to the text encoder individually. + """ + max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) + if max_embeddings_multiples > 1: + text_embeddings = [] + pool = None + for i in range(max_embeddings_multiples): + # extract the i-th chunk + text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone() + + # cover the head and the tail by the starting and the ending tokens + text_input_chunk[:, 0] = text_input[0, 0] + if pad == eos: # v1 + text_input_chunk[:, -1] = text_input[0, -1] + else: # v2 + for j in range(len(text_input_chunk)): + if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある + text_input_chunk[j, -1] = eos + if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD + text_input_chunk[j, 1] = eos + + # in sdxl, value of clip_skip is same for Text Encoder 1 and 2 + enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True) + text_embedding = enc_out["hidden_states"][-clip_skip] + if not is_sdxl: # SD 1.5 requires final_layer_norm + text_embedding = text_encoder.text_model.final_layer_norm(text_embedding) + if pool is None: + pool = enc_out.get("text_embeds", None) # use 1st chunk, if provided + if pool is not None: + pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], text_input_chunk, eos) + + if no_boseos_middle: + if i == 0: + # discard the ending token + text_embedding = text_embedding[:, :-1] + elif i == max_embeddings_multiples - 1: + # discard the starting token + text_embedding = text_embedding[:, 1:] + else: + # discard both starting and ending tokens + text_embedding = text_embedding[:, 1:-1] + + text_embeddings.append(text_embedding) + text_embeddings = torch.concat(text_embeddings, axis=1) + else: + enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True) + text_embeddings = enc_out["hidden_states"][-clip_skip] + if not is_sdxl: # SD 1.5 requires final_layer_norm + text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings) + pool = enc_out.get("text_embeds", None) # text encoder 1 doesn't return this + if pool is not None: + pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], text_input, eos) + return text_embeddings, pool + + +def get_weighted_text_embeddings( + is_sdxl: bool, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModel, + prompt: Union[str, List[str]], + uncond_prompt: Optional[Union[str, List[str]]] = None, + max_embeddings_multiples: Optional[int] = 1, + no_boseos_middle: Optional[bool] = False, + skip_parsing: Optional[bool] = False, + skip_weighting: Optional[bool] = False, + clip_skip: int = 1, + token_replacer=None, + device=None, + emb_normalize_mode: Optional[str] = "original", # "original", "abs", "none" + **kwargs, +): + max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 + if isinstance(prompt, str): + prompt = [prompt] + + # split the prompts with "AND". each prompt must have the same number of splits + new_prompts = [] + for p in prompt: + new_prompts.extend(p.split(" AND ")) + prompt = new_prompts + + if not skip_parsing: + prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, token_replacer, prompt, max_length - 2) + if uncond_prompt is not None: + if isinstance(uncond_prompt, str): + uncond_prompt = [uncond_prompt] + uncond_tokens, uncond_weights = get_prompts_with_weights(tokenizer, token_replacer, uncond_prompt, max_length - 2) + else: + prompt_tokens = [token[1:-1] for token in tokenizer(prompt, max_length=max_length, truncation=True).input_ids] + prompt_weights = [[1.0] * len(token) for token in prompt_tokens] + if uncond_prompt is not None: + if isinstance(uncond_prompt, str): + uncond_prompt = [uncond_prompt] + uncond_tokens = [token[1:-1] for token in tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids] + uncond_weights = [[1.0] * len(token) for token in uncond_tokens] + + # round up the longest length of tokens to a multiple of (model_max_length - 2) + max_length = max([len(token) for token in prompt_tokens]) + if uncond_prompt is not None: + max_length = max(max_length, max([len(token) for token in uncond_tokens])) + + max_embeddings_multiples = min( + max_embeddings_multiples, + (max_length - 1) // (tokenizer.model_max_length - 2) + 1, + ) + max_embeddings_multiples = max(1, max_embeddings_multiples) + max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 + + # pad the length of tokens and weights + bos = tokenizer.bos_token_id + eos = tokenizer.eos_token_id + pad = tokenizer.pad_token_id + prompt_tokens, prompt_weights = pad_tokens_and_weights( + prompt_tokens, + prompt_weights, + max_length, + bos, + eos, + pad, + no_boseos_middle=no_boseos_middle, + chunk_length=tokenizer.model_max_length, + ) + prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device) + if uncond_prompt is not None: + uncond_tokens, uncond_weights = pad_tokens_and_weights( + uncond_tokens, + uncond_weights, + max_length, + bos, + eos, + pad, + no_boseos_middle=no_boseos_middle, + chunk_length=tokenizer.model_max_length, + ) + uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=device) + + # get the embeddings + text_embeddings, text_pool = get_unweighted_text_embeddings( + is_sdxl, + text_encoder, + prompt_tokens, + tokenizer.model_max_length, + clip_skip, + eos, + pad, + no_boseos_middle=no_boseos_middle, + ) + + prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device) + if uncond_prompt is not None: + uncond_embeddings, uncond_pool = get_unweighted_text_embeddings( + is_sdxl, + text_encoder, + uncond_tokens, + tokenizer.model_max_length, + clip_skip, + eos, + pad, + no_boseos_middle=no_boseos_middle, + ) + uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=device) + + # assign weights to the prompts and normalize in the sense of mean + # TODO: should we normalize by chunk or in a whole (current implementation)? + # →全体でいいんじゃないかな + + if (not skip_parsing) and (not skip_weighting): + if emb_normalize_mode == "abs": + previous_mean = text_embeddings.float().abs().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= prompt_weights.unsqueeze(-1) + current_mean = text_embeddings.float().abs().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + if uncond_prompt is not None: + previous_mean = uncond_embeddings.float().abs().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= uncond_weights.unsqueeze(-1) + current_mean = uncond_embeddings.float().abs().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + + elif emb_normalize_mode == "none": + text_embeddings *= prompt_weights.unsqueeze(-1) + if uncond_prompt is not None: + uncond_embeddings *= uncond_weights.unsqueeze(-1) + + else: # "original" + previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= prompt_weights.unsqueeze(-1) + current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + if uncond_prompt is not None: + previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= uncond_weights.unsqueeze(-1) + current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + + if uncond_prompt is not None: + return text_embeddings, text_pool, uncond_embeddings, uncond_pool, prompt_tokens + return text_embeddings, text_pool, None, None, prompt_tokens + + +def preprocess_image(image): + w, h = image.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 + + +def preprocess_mask(mask): + mask = mask.convert("L") + w, h = mask.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + mask = mask.resize((w // 8, h // 8), resample=PIL.Image.BILINEAR) # LANCZOS) + mask = np.array(mask).astype(np.float32) / 255.0 + mask = np.tile(mask, (4, 1, 1)) + mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? + mask = 1 - mask # repaint white, keep black + mask = torch.from_numpy(mask) + return mask + + +# regular expression for dynamic prompt: +# starts and ends with "{" and "}" +# contains at least one variant divided by "|" +# optional framgments divided by "$$" at start +# if the first fragment is "E" or "e", enumerate all variants +# if the second fragment is a number or two numbers, repeat the variants in the range +# if the third fragment is a string, use it as a separator + +RE_DYNAMIC_PROMPT = re.compile(r"\{((e|E)\$\$)?(([\d\-]+)\$\$)?(([^\|\}]+?)\$\$)?(.+?((\|).+?)*?)\}") + + +def handle_dynamic_prompt_variants(prompt, repeat_count): + founds = list(RE_DYNAMIC_PROMPT.finditer(prompt)) + if not founds: + return [prompt] + + # make each replacement for each variant + enumerating = False + replacers = [] + for found in founds: + # if "e$$" is found, enumerate all variants + found_enumerating = found.group(2) is not None + enumerating = enumerating or found_enumerating + + separator = ", " if found.group(6) is None else found.group(6) + variants = found.group(7).split("|") + + # parse count range + count_range = found.group(4) + if count_range is None: + count_range = [1, 1] + else: + count_range = count_range.split("-") + if len(count_range) == 1: + count_range = [int(count_range[0]), int(count_range[0])] + elif len(count_range) == 2: + count_range = [int(count_range[0]), int(count_range[1])] + else: + logger.warning(f"invalid count range: {count_range}") + count_range = [1, 1] + if count_range[0] > count_range[1]: + count_range = [count_range[1], count_range[0]] + if count_range[0] < 0: + count_range[0] = 0 + if count_range[1] > len(variants): + count_range[1] = len(variants) + + if found_enumerating: + # make function to enumerate all combinations + def make_replacer_enum(vari, cr, sep): + def replacer(): + values = [] + for count in range(cr[0], cr[1] + 1): + for comb in itertools.combinations(vari, count): + values.append(sep.join(comb)) + return values + + return replacer + + replacers.append(make_replacer_enum(variants, count_range, separator)) + else: + # make function to choose random combinations + def make_replacer_single(vari, cr, sep): + def replacer(): + count = random.randint(cr[0], cr[1]) + comb = random.sample(vari, count) + return [sep.join(comb)] + + return replacer + + replacers.append(make_replacer_single(variants, count_range, separator)) + + # make each prompt + if not enumerating: + # if not enumerating, repeat the prompt, replace each variant randomly + prompts = [] + for _ in range(repeat_count): + current = prompt + for found, replacer in zip(founds, replacers): + current = current.replace(found.group(0), replacer()[0], 1) + prompts.append(current) + else: + # if enumerating, iterate all combinations for previous prompts + prompts = [prompt] + + for found, replacer in zip(founds, replacers): + if found.group(2) is not None: + # make all combinations for existing prompts + new_prompts = [] + for current in prompts: + replecements = replacer() + for replecement in replecements: + new_prompts.append(current.replace(found.group(0), replecement, 1)) + prompts = new_prompts + + for found, replacer in zip(founds, replacers): + # make random selection for existing prompts + if found.group(2) is None: + for i in range(len(prompts)): + prompts[i] = prompts[i].replace(found.group(0), replacer()[0], 1) + + return prompts + + +# endregion + +# def load_clip_l14_336(dtype): +# print(f"loading CLIP: {CLIP_ID_L14_336}") +# text_encoder = CLIPTextModel.from_pretrained(CLIP_ID_L14_336, torch_dtype=dtype) +# return text_encoder + + +class BatchDataBase(NamedTuple): + # バッチ分割が必要ないデータ + step: int + prompt: str + negative_prompt: str + seed: int + init_image: Any + mask_image: Any + clip_prompt: str + guide_image: Any + raw_prompt: str + + +class BatchDataExt(NamedTuple): + # バッチ分割が必要なデータ + width: int + height: int + original_width: int + original_height: int + original_width_negative: int + original_height_negative: int + crop_left: int + crop_top: int + steps: int + scale: float + negative_scale: float + strength: float + network_muls: Tuple[float] + num_sub_prompts: int + + +class BatchData(NamedTuple): + return_latents: bool + base: BatchDataBase + ext: BatchDataExt + + +class ListPrompter: + def __init__(self, prompts: List[str]): + self.prompts = prompts + self.index = 0 + + def shuffle(self): + random.shuffle(self.prompts) + + def __len__(self): + return len(self.prompts) + + def __call__(self, *args, **kwargs): + if self.index >= len(self.prompts): + self.index = 0 # reset + return None + + prompt = self.prompts[self.index] + self.index += 1 + return prompt + + +def main(args): + if args.fp16: + dtype = torch.float16 + elif args.bf16: + dtype = torch.bfloat16 + else: + dtype = torch.float32 + + highres_fix = args.highres_fix_scale is not None + # assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません" + + if args.v_parameterization and not args.v2: + logger.warning("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") + if args.v2 and args.clip_skip is not None: + logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") + + # モデルを読み込む + if not os.path.exists(args.ckpt): # ファイルがないならパターンで探し、一つだけ該当すればそれを使う + files = glob.glob(args.ckpt) + if len(files) == 1: + args.ckpt = files[0] + + name_or_path = os.readlink(args.ckpt) if os.path.islink(args.ckpt) else args.ckpt + use_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers + + # SDXLかどうかを判定する + is_sdxl = args.sdxl + if not is_sdxl and not args.v1 and not args.v2: # どれも指定されていない場合は自動で判定する + if use_stable_diffusion_format: + # if file size > 5.5GB, sdxl + is_sdxl = os.path.getsize(name_or_path) > 5.5 * 1024**3 + else: + # if `text_encoder_2` subdirectory exists, sdxl + is_sdxl = os.path.isdir(os.path.join(name_or_path, "text_encoder_2")) + logger.info(f"SDXL: {is_sdxl}") + + if is_sdxl: + if args.clip_skip is None: + args.clip_skip = 2 + + (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model( + args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype + ) + unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet) + text_encoders = [text_encoder1, text_encoder2] + else: + if args.clip_skip is None: + args.clip_skip = 2 if args.v2 else 1 + + if use_stable_diffusion_format: + logger.info("load StableDiffusion checkpoint") + text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt) + else: + logger.info("load Diffusers pretrained models") + loading_pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype) + text_encoder = loading_pipe.text_encoder + vae = loading_pipe.vae + unet = loading_pipe.unet + tokenizer = loading_pipe.tokenizer + del loading_pipe + + # Diffusers U-Net to original U-Net + original_unet = UNet2DConditionModel( + unet.config.sample_size, + unet.config.attention_head_dim, + unet.config.cross_attention_dim, + unet.config.use_linear_projection, + unet.config.upcast_attention, + ) + original_unet.load_state_dict(unet.state_dict()) + unet = original_unet + unet: InferUNet2DConditionModel = InferUNet2DConditionModel(unet) + text_encoders = [text_encoder] + + # VAEを読み込む + if args.vae is not None: + vae = model_util.load_vae(args.vae, dtype) + logger.info("additional VAE loaded") + + # xformers、Hypernetwork対応 + if not args.diffusers_xformers: + mem_eff = not (args.xformers or args.sdpa) + replace_unet_modules(unet, mem_eff, args.xformers, args.sdpa) + replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa) + + # tokenizerを読み込む + logger.info("loading tokenizer") + if is_sdxl: + tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) + tokenizers = [tokenizer1, tokenizer2] + else: + if use_stable_diffusion_format: + tokenizer = train_util.load_tokenizer(args) + tokenizers = [tokenizer] + + # schedulerを用意する + sched_init_args = {} + has_steps_offset = True + has_clip_sample = True + scheduler_num_noises_per_step = 1 + + if args.sampler == "ddim": + scheduler_cls = DDIMScheduler + scheduler_module = diffusers.schedulers.scheduling_ddim + elif args.sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある + scheduler_cls = DDPMScheduler + scheduler_module = diffusers.schedulers.scheduling_ddpm + elif args.sampler == "pndm": + scheduler_cls = PNDMScheduler + scheduler_module = diffusers.schedulers.scheduling_pndm + has_clip_sample = False + elif args.sampler == "lms" or args.sampler == "k_lms": + scheduler_cls = LMSDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_lms_discrete + has_clip_sample = False + elif args.sampler == "euler" or args.sampler == "k_euler": + scheduler_cls = EulerDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_euler_discrete + has_clip_sample = False + elif args.sampler == "euler_a" or args.sampler == "k_euler_a": + scheduler_cls = EulerAncestralDiscreteSchedulerGL + scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete + has_clip_sample = False + elif args.sampler == "dpmsolver" or args.sampler == "dpmsolver++": + scheduler_cls = DPMSolverMultistepScheduler + sched_init_args["algorithm_type"] = args.sampler + scheduler_module = diffusers.schedulers.scheduling_dpmsolver_multistep + has_clip_sample = False + elif args.sampler == "dpmsingle": + scheduler_cls = DPMSolverSinglestepScheduler + scheduler_module = diffusers.schedulers.scheduling_dpmsolver_singlestep + has_clip_sample = False + has_steps_offset = False + elif args.sampler == "heun": + scheduler_cls = HeunDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_heun_discrete + has_clip_sample = False + elif args.sampler == "dpm_2" or args.sampler == "k_dpm_2": + scheduler_cls = KDPM2DiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_discrete + has_clip_sample = False + elif args.sampler == "dpm_2_a" or args.sampler == "k_dpm_2_a": + scheduler_cls = KDPM2AncestralDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete + scheduler_num_noises_per_step = 2 + has_clip_sample = False + + if args.v_parameterization: + sched_init_args["prediction_type"] = "v_prediction" + + # 警告を出さないようにする + if has_steps_offset: + sched_init_args["steps_offset"] = 1 + if has_clip_sample: + sched_init_args["clip_sample"] = False + + # samplerの乱数をあらかじめ指定するための処理 + + # replace randn + class NoiseManager: + def __init__(self): + self.sampler_noises = None + self.sampler_noise_index = 0 + + def reset_sampler_noises(self, noises): + self.sampler_noise_index = 0 + self.sampler_noises = noises + + def randn(self, shape, device=None, dtype=None, layout=None, generator=None): + # print("replacing", shape, len(self.sampler_noises), self.sampler_noise_index) + if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises): + noise = self.sampler_noises[self.sampler_noise_index] + if shape != noise.shape: + noise = None + else: + noise = None + + if noise == None: + logger.warning(f"unexpected noise request: {self.sampler_noise_index}, {shape}") + noise = torch.randn(shape, dtype=dtype, device=device, generator=generator) + + self.sampler_noise_index += 1 + return noise + + class TorchRandReplacer: + def __init__(self, noise_manager): + self.noise_manager = noise_manager + + def __getattr__(self, item): + if item == "randn": + return self.noise_manager.randn + if hasattr(torch, item): + return getattr(torch, item) + raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) + + noise_manager = NoiseManager() + if scheduler_module is not None: + scheduler_module.torch = TorchRandReplacer(noise_manager) + + scheduler = scheduler_cls( + num_train_timesteps=SCHEDULER_TIMESTEPS, + beta_start=SCHEDULER_LINEAR_START, + beta_end=SCHEDULER_LINEAR_END, + beta_schedule=SCHEDLER_SCHEDULE, + **sched_init_args, + ) + + # ↓以下は結局PipeでFalseに設定されるので意味がなかった + # # clip_sample=Trueにする + # if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: + # print("set clip_sample to True") + # scheduler.config.clip_sample = True + + # deviceを決定する + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # "mps"を考量してない + + # custom pipelineをコピったやつを生成する + if args.vae_slices: + from library.slicing_vae import SlicingAutoencoderKL + + sli_vae = SlicingAutoencoderKL( + act_fn="silu", + block_out_channels=(128, 256, 512, 512), + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"], + in_channels=3, + latent_channels=4, + layers_per_block=2, + norm_num_groups=32, + out_channels=3, + sample_size=512, + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"], + num_slices=args.vae_slices, + ) + sli_vae.load_state_dict(vae.state_dict()) # vaeのパラメータをコピーする + vae = sli_vae + del sli_vae + + vae_dtype = dtype + if args.no_half_vae: + logger.info("set vae_dtype to float32") + vae_dtype = torch.float32 + vae.to(vae_dtype).to(device) + vae.eval() + + for text_encoder in text_encoders: + text_encoder.to(dtype).to(device) + text_encoder.eval() + unet.to(dtype).to(device) + unet.eval() + + # networkを組み込む + if args.network_module: + networks = [] + network_default_muls = [] + network_pre_calc = args.network_pre_calc + + # merge関連の引数を統合する + if args.network_merge: + network_merge = len(args.network_module) # all networks are merged + elif args.network_merge_n_models: + network_merge = args.network_merge_n_models + else: + network_merge = 0 + logger.info(f"network_merge: {network_merge}") + + for i, network_module in enumerate(args.network_module): + logger.info("import network module: {network_module}") + imported_module = importlib.import_module(network_module) + + network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] + + net_kwargs = {} + if args.network_args and i < len(args.network_args): + network_args = args.network_args[i] + # TODO escape special chars + network_args = network_args.split(";") + for net_arg in network_args: + key, value = net_arg.split("=") + net_kwargs[key] = value + + if args.network_weights is None or len(args.network_weights) <= i: + raise ValueError("No weight. Weight is required.") + + network_weight = args.network_weights[i] + logger.info(f"load network weights from: {network_weight}") + + if model_util.is_safetensors(network_weight) and args.network_show_meta: + from safetensors.torch import safe_open + + with safe_open(network_weight, framework="pt") as f: + metadata = f.metadata() + if metadata is not None: + logger.info(f"metadata for: {network_weight}: {metadata}") + + network, weights_sd = imported_module.create_network_from_weights( + network_mul, network_weight, vae, text_encoders, unet, for_inference=True, **net_kwargs + ) + if network is None: + return + + mergeable = network.is_mergeable() + if network_merge and not mergeable: + logger.warning("network is not mergiable. ignore merge option.") + + if not mergeable or i >= network_merge: + # not merging + network.apply_to(text_encoders, unet) + info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい + logger.info(f"weights are loaded: {info}") + + if args.opt_channels_last: + network.to(memory_format=torch.channels_last) + network.to(dtype).to(device) + + if network_pre_calc: + logger.info("backup original weights") + network.backup_weights() + + networks.append(network) + network_default_muls.append(network_mul) + else: + network.merge_to(text_encoders, unet, weights_sd, dtype, device) + + else: + networks = [] + + # upscalerの指定があれば取得する + upscaler = None + if args.highres_fix_upscaler: + logger.info("import upscaler module: {args.highres_fix_upscaler}") + imported_module = importlib.import_module(args.highres_fix_upscaler) + + us_kwargs = {} + if args.highres_fix_upscaler_args: + for net_arg in args.highres_fix_upscaler_args.split(";"): + key, value = net_arg.split("=") + us_kwargs[key] = value + + logger.info("create upscaler") + upscaler = imported_module.create_upscaler(**us_kwargs) + upscaler.to(dtype).to(device) + + # ControlNetの処理 + control_nets: List[ControlNetInfo] = [] + if args.control_net_models: + for i, model in enumerate(args.control_net_models): + prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i] + weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i] + ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] + + ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model) + prep = original_control_net.load_preprocess(prep_type) + control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) + + control_net_lllites: List[Tuple[ControlNetLLLite, float]] = [] + if args.control_net_lllite_models: + for i, model_file in enumerate(args.control_net_lllite_models): + logger.info(f"loading ControlNet-LLLite: {model_file}") + + from safetensors.torch import load_file + + state_dict = load_file(model_file) + mlp_dim = None + cond_emb_dim = None + for key, value in state_dict.items(): + if mlp_dim is None and "down.0.weight" in key: + mlp_dim = value.shape[0] + elif cond_emb_dim is None and "conditioning1.0" in key: + cond_emb_dim = value.shape[0] * 2 + if mlp_dim is not None and cond_emb_dim is not None: + break + assert mlp_dim is not None and cond_emb_dim is not None, f"invalid control net: {model_file}" + + multiplier = ( + 1.0 + if not args.control_net_multipliers or len(args.control_net_multipliers) <= i + else args.control_net_multipliers[i] + ) + ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] + + control_net_lllite = ControlNetLLLite(unet, cond_emb_dim, mlp_dim, multiplier=multiplier) + control_net_lllite.apply_to() + control_net_lllite.load_state_dict(state_dict) + control_net_lllite.to(dtype).to(device) + control_net_lllite.set_batch_cond_only(False, False) + control_net_lllites.append((control_net_lllite, ratio)) + assert ( + len(control_nets) == 0 or len(control_net_lllites) == 0 + ), "ControlNet and ControlNet-LLLite cannot be used at the same time" + + if args.opt_channels_last: + logger.info(f"set optimizing: channels last") + for text_encoder in text_encoders: + text_encoder.to(memory_format=torch.channels_last) + vae.to(memory_format=torch.channels_last) + unet.to(memory_format=torch.channels_last) + if networks: + for network in networks: + network.to(memory_format=torch.channels_last) + + for cn in control_nets: + cn.to(memory_format=torch.channels_last) + + for cn in control_net_lllites: + cn.to(memory_format=torch.channels_last) + + pipe = PipelineLike( + is_sdxl, + device, + vae, + text_encoders, + tokenizers, + unet, + scheduler, + args.clip_skip, + ) + pipe.set_control_nets(control_nets) + pipe.set_control_net_lllites(control_net_lllites) + logger.info("pipeline is ready.") + + if args.diffusers_xformers: + pipe.enable_xformers_memory_efficient_attention() + + # Deep Shrink + if args.ds_depth_1 is not None: + unet.set_deep_shrink(args.ds_depth_1, args.ds_timesteps_1, args.ds_depth_2, args.ds_timesteps_2, args.ds_ratio) + + # Gradual Latent + if args.gradual_latent_timesteps is not None: + if args.gradual_latent_unsharp_params: + us_params = args.gradual_latent_unsharp_params.split(",") + us_ksize, us_sigma, us_strength = [float(v) for v in us_params[:3]] + us_target_x = True if len(us_params) <= 3 else bool(int(us_params[3])) + us_ksize = int(us_ksize) + else: + us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None + + gradual_latent = GradualLatent( + args.gradual_latent_ratio, + args.gradual_latent_timesteps, + args.gradual_latent_every_n_steps, + args.gradual_latent_ratio_step, + args.gradual_latent_s_noise, + us_ksize, + us_sigma, + us_strength, + us_target_x, + ) + pipe.set_gradual_latent(gradual_latent) + + # Textual Inversionを処理する + if args.textual_inversion_embeddings: + token_ids_embeds1 = [] + token_ids_embeds2 = [] + for embeds_file in args.textual_inversion_embeddings: + if model_util.is_safetensors(embeds_file): + from safetensors.torch import load_file + + data = load_file(embeds_file) + else: + data = torch.load(embeds_file, map_location="cpu") + + if "string_to_param" in data: + data = data["string_to_param"] + if is_sdxl: + + embeds1 = data["clip_l"] # text encoder 1 + embeds2 = data["clip_g"] # text encoder 2 + else: + embeds1 = next(iter(data.values())) + embeds2 = None + + num_vectors_per_token = embeds1.size()[0] + token_string = os.path.splitext(os.path.basename(embeds_file))[0] + + token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)] + + # add new word to tokenizer, count is num_vectors_per_token + num_added_tokens1 = tokenizers[0].add_tokens(token_strings) + num_added_tokens2 = tokenizers[1].add_tokens(token_strings) if is_sdxl else 0 + assert num_added_tokens1 == num_vectors_per_token and ( + num_added_tokens2 == 0 or num_added_tokens2 == num_vectors_per_token + ), ( + f"tokenizer has same word to token string (filename): {embeds_file}" + + f" / 指定した名前(ファイル名)のトークンが既に存在します: {embeds_file}" + ) + + token_ids1 = tokenizers[0].convert_tokens_to_ids(token_strings) + token_ids2 = tokenizers[1].convert_tokens_to_ids(token_strings) if is_sdxl else None + logger.info(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids1} and {token_ids2}") + assert ( + min(token_ids1) == token_ids1[0] and token_ids1[-1] == token_ids1[0] + len(token_ids1) - 1 + ), f"token ids1 is not ordered" + assert not is_sdxl or ( + min(token_ids2) == token_ids2[0] and token_ids2[-1] == token_ids2[0] + len(token_ids2) - 1 + ), f"token ids2 is not ordered" + assert len(tokenizers[0]) - 1 == token_ids1[-1], f"token ids 1 is not end of tokenize: {len(tokenizers[0])}" + assert ( + not is_sdxl or len(tokenizers[1]) - 1 == token_ids2[-1] + ), f"token ids 2 is not end of tokenize: {len(tokenizers[1])}" + + if num_vectors_per_token > 1: + pipe.add_token_replacement(0, token_ids1[0], token_ids1) # hoge -> hoge, hogea, hogeb, ... + if is_sdxl: + pipe.add_token_replacement(1, token_ids2[0], token_ids2) + + token_ids_embeds1.append((token_ids1, embeds1)) + if is_sdxl: + token_ids_embeds2.append((token_ids2, embeds2)) + + text_encoders[0].resize_token_embeddings(len(tokenizers[0])) + token_embeds1 = text_encoders[0].get_input_embeddings().weight.data + for token_ids, embeds in token_ids_embeds1: + for token_id, embed in zip(token_ids, embeds): + token_embeds1[token_id] = embed + + if is_sdxl: + text_encoders[1].resize_token_embeddings(len(tokenizers[1])) + token_embeds2 = text_encoders[1].get_input_embeddings().weight.data + for token_ids, embeds in token_ids_embeds2: + for token_id, embed in zip(token_ids, embeds): + token_embeds2[token_id] = embed + + # promptを取得する + prompt_list = None + if args.from_file is not None: + logger.info(f"reading prompts from {args.from_file}") + with open(args.from_file, "r", encoding="utf-8") as f: + prompt_list = f.read().splitlines() + prompt_list = [d for d in prompt_list if len(d.strip()) > 0 and d[0] != "#"] + prompter = ListPrompter(prompt_list) + + elif args.from_module is not None: + + def load_module_from_path(module_name, file_path): + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Module '{module_name}' cannot be loaded from '{file_path}'") + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + logger.info(f"reading prompts from module: {args.from_module}") + prompt_module = load_module_from_path("prompt_module", args.from_module) + + prompter = prompt_module.get_prompter(args, pipe, networks) + + elif args.prompt is not None: + prompter = ListPrompter([args.prompt]) + + else: + prompter = None # interactive mode + + if args.interactive: + args.n_iter = 1 + + # img2imgの前処理、画像の読み込みなど + def load_images(path): + if os.path.isfile(path): + paths = [path] + else: + paths = ( + glob.glob(os.path.join(path, "*.png")) + + glob.glob(os.path.join(path, "*.jpg")) + + glob.glob(os.path.join(path, "*.jpeg")) + + glob.glob(os.path.join(path, "*.webp")) + ) + paths.sort() + + images = [] + for p in paths: + image = Image.open(p) + if image.mode != "RGB": + logger.info(f"convert image to RGB from {image.mode}: {p}") + image = image.convert("RGB") + images.append(image) + + return images + + def resize_images(imgs, size): + resized = [] + for img in imgs: + r_img = img.resize(size, Image.Resampling.LANCZOS) + if hasattr(img, "filename"): # filename属性がない場合があるらしい + r_img.filename = img.filename + resized.append(r_img) + return resized + + if args.image_path is not None: + logger.info(f"load image for img2img: {args.image_path}") + init_images = load_images(args.image_path) + assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}" + logger.info(f"loaded {len(init_images)} images for img2img") + + # CLIP Vision + if args.clip_vision_strength is not None: + logger.info(f"load CLIP Vision model: {CLIP_VISION_MODEL}") + vision_model = CLIPVisionModelWithProjection.from_pretrained(CLIP_VISION_MODEL, projection_dim=1280) + vision_model.to(device, dtype) + processor = CLIPImageProcessor.from_pretrained(CLIP_VISION_MODEL) + + pipe.clip_vision_model = vision_model + pipe.clip_vision_processor = processor + pipe.clip_vision_strength = args.clip_vision_strength + logger.info(f"CLIP Vision model loaded.") + + else: + init_images = None + + if args.mask_path is not None: + logger.info(f"load mask for inpainting: {args.mask_path}") + mask_images = load_images(args.mask_path) + assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.image_path}" + logger.info(f"loaded {len(mask_images)} mask images for inpainting") + else: + mask_images = None + + # promptがないとき、画像のPngInfoから取得する + if init_images is not None and prompter is None and not args.interactive: + logger.info("get prompts from images' metadata") + prompt_list = [] + for img in init_images: + if "prompt" in img.text: + prompt = img.text["prompt"] + if "negative-prompt" in img.text: + prompt += " --n " + img.text["negative-prompt"] + prompt_list.append(prompt) + prompter = ListPrompter(prompt_list) + + # プロンプトと画像を一致させるため指定回数だけ繰り返す(画像を増幅する) + l = [] + for im in init_images: + l.extend([im] * args.images_per_prompt) + init_images = l + + if mask_images is not None: + l = [] + for im in mask_images: + l.extend([im] * args.images_per_prompt) + mask_images = l + + # 画像サイズにオプション指定があるときはリサイズする + if args.W is not None and args.H is not None: + # highres fix を考慮に入れる + w, h = args.W, args.H + if highres_fix: + w = int(w * args.highres_fix_scale + 0.5) + h = int(h * args.highres_fix_scale + 0.5) + + if init_images is not None: + logger.info(f"resize img2img source images to {w}*{h}") + init_images = resize_images(init_images, (w, h)) + if mask_images is not None: + logger.info(f"resize img2img mask images to {w}*{h}") + mask_images = resize_images(mask_images, (w, h)) + + regional_network = False + if networks and mask_images: + # mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応 + regional_network = True + logger.info("use mask as region") + + size = None + for i, network in enumerate(networks): + if (i < 3 and args.network_regional_mask_max_color_codes is None) or i < args.network_regional_mask_max_color_codes: + np_mask = np.array(mask_images[0]) + + if args.network_regional_mask_max_color_codes: + # カラーコードでマスクを指定する + ch0 = (i + 1) & 1 + ch1 = ((i + 1) >> 1) & 1 + ch2 = ((i + 1) >> 2) & 1 + np_mask = np.all(np_mask == np.array([ch0, ch1, ch2]) * 255, axis=2) + np_mask = np_mask.astype(np.uint8) * 255 + else: + np_mask = np_mask[:, :, i] + size = np_mask.shape + else: + np_mask = np.full(size, 255, dtype=np.uint8) + mask = torch.from_numpy(np_mask.astype(np.float32) / 255.0) + network.set_region(i, i == len(networks) - 1, mask) + mask_images = None + + prev_image = None # for VGG16 guided + if args.guide_image_path is not None: + logger.info(f"load image for ControlNet guidance: {args.guide_image_path}") + guide_images = [] + for p in args.guide_image_path: + guide_images.extend(load_images(p)) + + logger.info(f"loaded {len(guide_images)} guide images for guidance") + if len(guide_images) == 0: + logger.warning( + f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}" + ) + guide_images = None + else: + guide_images = None + + # 新しい乱数生成器を作成する + if args.seed is not None: + if prompt_list and len(prompt_list) == 1 and args.images_per_prompt == 1: + # 引数のseedをそのまま使う + def fixed_seed(*args, **kwargs): + return args.seed + + seed_random = SimpleNamespace(randint=fixed_seed) + else: + seed_random = random.Random(args.seed) + else: + seed_random = random.Random() + + # デフォルト画像サイズを設定する:img2imgではこれらの値は無視される(またはW*Hにリサイズ済み) + if args.W is None: + args.W = 1024 if is_sdxl else 512 + if args.H is None: + args.H = 1024 if is_sdxl else 512 + + # 画像生成のループ + os.makedirs(args.outdir, exist_ok=True) + max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples + + for gen_iter in range(args.n_iter): + logger.info(f"iteration {gen_iter+1}/{args.n_iter}") + if args.iter_same_seed: + iter_seed = seed_random.randint(0, 2**32 - 1) + else: + iter_seed = None + + # shuffle prompt list + if args.shuffle_prompts: + prompter.shuffle() + + # バッチ処理の関数 + def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): + batch_size = len(batch) + + # highres_fixの処理 + if highres_fix and not highres_1st: + # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す + is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling + + logger.info("process 1st stage") + batch_1st = [] + for _, base, ext in batch: + + def scale_and_round(x): + if x is None: + return None + return int(x * args.highres_fix_scale + 0.5) + + width_1st = scale_and_round(ext.width) + height_1st = scale_and_round(ext.height) + width_1st = width_1st - width_1st % 32 + height_1st = height_1st - height_1st % 32 + + original_width_1st = scale_and_round(ext.original_width) + original_height_1st = scale_and_round(ext.original_height) + original_width_negative_1st = scale_and_round(ext.original_width_negative) + original_height_negative_1st = scale_and_round(ext.original_height_negative) + crop_left_1st = scale_and_round(ext.crop_left) + crop_top_1st = scale_and_round(ext.crop_top) + + strength_1st = ext.strength if args.highres_fix_strength is None else args.highres_fix_strength + + ext_1st = BatchDataExt( + width_1st, + height_1st, + original_width_1st, + original_height_1st, + original_width_negative_1st, + original_height_negative_1st, + crop_left_1st, + crop_top_1st, + args.highres_fix_steps, + ext.scale, + ext.negative_scale, + strength_1st, + ext.network_muls, + ext.num_sub_prompts, + ) + batch_1st.append(BatchData(is_1st_latent, base, ext_1st)) + + pipe.set_enable_control_net(True) # 1st stageではControlNetを有効にする + images_1st = process_batch(batch_1st, True, True) + + # 2nd stageのバッチを作成して以下処理する + logger.info("process 2nd stage") + width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height + + if upscaler: + # upscalerを使って画像を拡大する + lowreso_imgs = None if is_1st_latent else images_1st + lowreso_latents = None if not is_1st_latent else images_1st + + # 戻り値はPIL.Image.Imageかtorch.Tensorのlatents + batch_size = len(images_1st) + vae_batch_size = ( + batch_size + if args.vae_batch_size is None + else (max(1, int(batch_size * args.vae_batch_size)) if args.vae_batch_size < 1 else args.vae_batch_size) + ) + vae_batch_size = int(vae_batch_size) + images_1st = upscaler.upscale( + vae, lowreso_imgs, lowreso_latents, dtype, width_2nd, height_2nd, batch_size, vae_batch_size + ) + + elif args.highres_fix_latents_upscaling: + # latentを拡大する + org_dtype = images_1st.dtype + if images_1st.dtype == torch.bfloat16: + images_1st = images_1st.to(torch.float) # interpolateがbf16をサポートしていない + images_1st = torch.nn.functional.interpolate( + images_1st, (batch[0].ext.height // 8, batch[0].ext.width // 8), mode="bilinear" + ) # , antialias=True) + images_1st = images_1st.to(org_dtype) + + else: + # 画像をLANCZOSで拡大する + images_1st = [image.resize((width_2nd, height_2nd), resample=PIL.Image.LANCZOS) for image in images_1st] + + batch_2nd = [] + for i, (bd, image) in enumerate(zip(batch, images_1st)): + bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed + 1, image, None, *bd.base[6:]), bd.ext) + batch_2nd.append(bd_2nd) + batch = batch_2nd + + if args.highres_fix_disable_control_net: + pipe.set_enable_control_net(False) # オプション指定時、2nd stageではControlNetを無効にする + + # このバッチの情報を取り出す + ( + return_latents, + (step_first, _, _, _, init_image, mask_image, _, guide_image, _), + ( + width, + height, + original_width, + original_height, + original_width_negative, + original_height_negative, + crop_left, + crop_top, + steps, + scale, + negative_scale, + strength, + network_muls, + num_sub_prompts, + ), + ) = batch[0] + noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR) + + prompts = [] + negative_prompts = [] + raw_prompts = [] + start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) + noises = [ + torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) + for _ in range(steps * scheduler_num_noises_per_step) + ] + seeds = [] + clip_prompts = [] + + if init_image is not None: # img2img? + i2i_noises = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) + init_images = [] + + if mask_image is not None: + mask_images = [] + else: + mask_images = None + else: + i2i_noises = None + init_images = None + mask_images = None + + if guide_image is not None: # CLIP image guided? + guide_images = [] + else: + guide_images = None + + # バッチ内の位置に関わらず同じ乱数を使うためにここで乱数を生成しておく。あわせてimage/maskがbatch内で同一かチェックする + all_images_are_same = True + all_masks_are_same = True + all_guide_images_are_same = True + for i, ( + _, + (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt), + _, + ) in enumerate(batch): + prompts.append(prompt) + negative_prompts.append(negative_prompt) + seeds.append(seed) + clip_prompts.append(clip_prompt) + raw_prompts.append(raw_prompt) + + if init_image is not None: + init_images.append(init_image) + if i > 0 and all_images_are_same: + all_images_are_same = init_images[-2] is init_image + + if mask_image is not None: + mask_images.append(mask_image) + if i > 0 and all_masks_are_same: + all_masks_are_same = mask_images[-2] is mask_image + + if guide_image is not None: + if type(guide_image) is list: + guide_images.extend(guide_image) + all_guide_images_are_same = False + else: + guide_images.append(guide_image) + if i > 0 and all_guide_images_are_same: + all_guide_images_are_same = guide_images[-2] is guide_image + + # make start code + torch.manual_seed(seed) + start_code[i] = torch.randn(noise_shape, device=device, dtype=dtype) + + # make each noises + for j in range(steps * scheduler_num_noises_per_step): + noises[j][i] = torch.randn(noise_shape, device=device, dtype=dtype) + + if i2i_noises is not None: # img2img noise + i2i_noises[i] = torch.randn(noise_shape, device=device, dtype=dtype) + + noise_manager.reset_sampler_noises(noises) + + # すべての画像が同じなら1枚だけpipeに渡すことでpipe側で処理を高速化する + if init_images is not None and all_images_are_same: + init_images = init_images[0] + if mask_images is not None and all_masks_are_same: + mask_images = mask_images[0] + if guide_images is not None and all_guide_images_are_same: + guide_images = guide_images[0] + + # ControlNet使用時はguide imageをリサイズする + if control_nets or control_net_lllites: + # TODO resampleのメソッド + guide_images = guide_images if type(guide_images) == list else [guide_images] + guide_images = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in guide_images] + if len(guide_images) == 1: + guide_images = guide_images[0] + + # generate + if networks: + # 追加ネットワークの処理 + shared = {} + for n, m in zip(networks, network_muls if network_muls else network_default_muls): + n.set_multiplier(m) + if regional_network: + # TODO バッチから ds_ratio を取り出すべき + n.set_current_generation(batch_size, num_sub_prompts, width, height, shared, unet.ds_ratio) + + if not regional_network and network_pre_calc: + for n in networks: + n.restore_weights() + for n in networks: + n.pre_calculation() + logger.info("pre-calculation... done") + + images = pipe( + prompts, + negative_prompts, + init_images, + mask_images, + height, + width, + original_height, + original_width, + original_height_negative, + original_width_negative, + crop_top, + crop_left, + steps, + scale, + negative_scale, + strength, + latents=start_code, + output_type="pil", + max_embeddings_multiples=max_embeddings_multiples, + img2img_noise=i2i_noises, + vae_batch_size=args.vae_batch_size, + return_latents=return_latents, + clip_prompts=clip_prompts, + clip_guide_images=guide_images, + emb_normalize_mode=args.emb_normalize_mode, + ) + if highres_1st and not args.highres_fix_save_1st: # return images or latents + return images + + # save image + highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate( + zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts) + ): + if highres_fix: + seed -= 1 # record original seed + metadata = PngInfo() + metadata.add_text("prompt", prompt) + metadata.add_text("seed", str(seed)) + metadata.add_text("sampler", args.sampler) + metadata.add_text("steps", str(steps)) + metadata.add_text("scale", str(scale)) + if negative_prompt is not None: + metadata.add_text("negative-prompt", negative_prompt) + if negative_scale is not None: + metadata.add_text("negative-scale", str(negative_scale)) + if clip_prompt is not None: + metadata.add_text("clip-prompt", clip_prompt) + if raw_prompt is not None: + metadata.add_text("raw-prompt", raw_prompt) + if is_sdxl: + metadata.add_text("original-height", str(original_height)) + metadata.add_text("original-width", str(original_width)) + metadata.add_text("original-height-negative", str(original_height_negative)) + metadata.add_text("original-width-negative", str(original_width_negative)) + metadata.add_text("crop-top", str(crop_top)) + metadata.add_text("crop-left", str(crop_left)) + + if args.use_original_file_name and init_images is not None: + if type(init_images) is list: + fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png" + else: + fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png" + elif args.sequential_file_name: + fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png" + else: + fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png" + + image.save(os.path.join(args.outdir, fln), pnginfo=metadata) + + if not args.no_preview and not highres_1st and args.interactive: + try: + import cv2 + + for prompt, image in zip(prompts, images): + cv2.imshow(prompt[:128], np.array(image)[:, :, ::-1]) # プロンプトが長いと死ぬ + cv2.waitKey() + cv2.destroyAllWindows() + except ImportError: + logger.warning( + "opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません" + ) + + return images + + # 画像生成のプロンプトが一周するまでのループ + prompt_index = 0 + global_step = 0 + batch_data = [] + while True: + if args.interactive: + # interactive + valid = False + while not valid: + logger.info("\nType prompt:") + try: + raw_prompt = input() + except EOFError: + break + + valid = len(raw_prompt.strip().split(" --")[0].strip()) > 0 + if not valid: # EOF, end app + break + else: + raw_prompt = prompter(args, pipe, seed_random, iter_seed, prompt_index, global_step) + if raw_prompt is None: + break + + # sd-dynamic-prompts like variants: + # count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration) + raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt) + + # repeat prompt + for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)): + raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0] + + if pi == 0 or len(raw_prompts) > 1: + # parse prompt: if prompt is not changed, skip parsing + width = args.W + height = args.H + original_width = args.original_width + original_height = args.original_height + original_width_negative = args.original_width_negative + original_height_negative = args.original_height_negative + crop_top = args.crop_top + crop_left = args.crop_left + scale = args.scale + negative_scale = args.negative_scale + steps = args.steps + seed = None + seeds = None + strength = 0.8 if args.strength is None else args.strength + negative_prompt = "" + clip_prompt = None + network_muls = None + + # Deep Shrink + ds_depth_1 = None # means no override + ds_timesteps_1 = args.ds_timesteps_1 + ds_depth_2 = args.ds_depth_2 + ds_timesteps_2 = args.ds_timesteps_2 + ds_ratio = args.ds_ratio + + # Gradual Latent + gl_timesteps = None # means no override + gl_ratio = args.gradual_latent_ratio + gl_every_n_steps = args.gradual_latent_every_n_steps + gl_ratio_step = args.gradual_latent_ratio_step + gl_s_noise = args.gradual_latent_s_noise + gl_unsharp_params = args.gradual_latent_unsharp_params + + prompt_args = raw_prompt.strip().split(" --") + prompt = prompt_args[0] + length = len(prompter) if hasattr(prompter, "__len__") else 0 + logger.info(f"prompt {prompt_index+1}/{length}: {prompt}") + + for parg in prompt_args[1:]: + try: + m = re.match(r"w (\d+)", parg, re.IGNORECASE) + if m: + width = int(m.group(1)) + logger.info(f"width: {width}") + continue + + m = re.match(r"h (\d+)", parg, re.IGNORECASE) + if m: + height = int(m.group(1)) + logger.info(f"height: {height}") + continue + + m = re.match(r"ow (\d+)", parg, re.IGNORECASE) + if m: + original_width = int(m.group(1)) + logger.info(f"original width: {original_width}") + continue + + m = re.match(r"oh (\d+)", parg, re.IGNORECASE) + if m: + original_height = int(m.group(1)) + logger.info(f"original height: {original_height}") + continue + + m = re.match(r"nw (\d+)", parg, re.IGNORECASE) + if m: + original_width_negative = int(m.group(1)) + logger.info(f"original width negative: {original_width_negative}") + continue + + m = re.match(r"nh (\d+)", parg, re.IGNORECASE) + if m: + original_height_negative = int(m.group(1)) + logger.info(f"original height negative: {original_height_negative}") + continue + + m = re.match(r"ct (\d+)", parg, re.IGNORECASE) + if m: + crop_top = int(m.group(1)) + logger.info(f"crop top: {crop_top}") + continue + + m = re.match(r"cl (\d+)", parg, re.IGNORECASE) + if m: + crop_left = int(m.group(1)) + logger.info(f"crop left: {crop_left}") + continue + + m = re.match(r"s (\d+)", parg, re.IGNORECASE) + if m: # steps + steps = max(1, min(1000, int(m.group(1)))) + logger.info(f"steps: {steps}") + continue + + m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) + if m: # seed + seeds = [int(d) for d in m.group(1).split(",")] + logger.info(f"seeds: {seeds}") + continue + + m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) + if m: # scale + scale = float(m.group(1)) + logger.info(f"scale: {scale}") + continue + + m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) + if m: # negative scale + if m.group(1).lower() == "none": + negative_scale = None + else: + negative_scale = float(m.group(1)) + logger.info(f"negative scale: {negative_scale}") + continue + + m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) + if m: # strength + strength = float(m.group(1)) + logger.info(f"strength: {strength}") + continue + + m = re.match(r"n (.+)", parg, re.IGNORECASE) + if m: # negative prompt + negative_prompt = m.group(1) + logger.info(f"negative prompt: {negative_prompt}") + continue + + m = re.match(r"c (.+)", parg, re.IGNORECASE) + if m: # clip prompt + clip_prompt = m.group(1) + logger.info(f"clip prompt: {clip_prompt}") + continue + + m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # network multiplies + network_muls = [float(v) for v in m.group(1).split(",")] + while len(network_muls) < len(networks): + network_muls.append(network_muls[-1]) + logger.info(f"network mul: {network_muls}") + continue + + # Deep Shrink + m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink depth 1 + ds_depth_1 = int(m.group(1)) + logger.info(f"deep shrink depth 1: {ds_depth_1}") + continue + + m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink timesteps 1 + ds_timesteps_1 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink timesteps 1: {ds_timesteps_1}") + continue + + m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink depth 2 + ds_depth_2 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink depth 2: {ds_depth_2}") + continue + + m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink timesteps 2 + ds_timesteps_2 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink timesteps 2: {ds_timesteps_2}") + continue + + m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink ratio + ds_ratio = float(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink ratio: {ds_ratio}") + continue + + # Gradual Latent + m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent timesteps + gl_timesteps = int(m.group(1)) + logger.info(f"gradual latent timesteps: {gl_timesteps}") + continue + + m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio + gl_ratio = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent ratio: {ds_ratio}") + continue + + m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent every n steps + gl_every_n_steps = int(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent every n steps: {gl_every_n_steps}") + continue + + m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio step + gl_ratio_step = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent ratio step: {gl_ratio_step}") + continue + + m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent s noise + gl_s_noise = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent s noise: {gl_s_noise}") + continue + + m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # gradual latent unsharp params + gl_unsharp_params = m.group(1) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") + continue + + except ValueError as ex: + logger.error(f"Exception in parsing / 解析エラー: {parg}") + logger.error(f"{ex}") + + # override Deep Shrink + if ds_depth_1 is not None: + if ds_depth_1 < 0: + ds_depth_1 = args.ds_depth_1 or 3 + unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio) + + # override Gradual Latent + if gl_timesteps is not None: + if gl_timesteps < 0: + gl_timesteps = args.gradual_latent_timesteps or 650 + if gl_unsharp_params is not None: + unsharp_params = gl_unsharp_params.split(",") + us_ksize, us_sigma, us_strength = [float(v) for v in unsharp_params[:3]] + us_target_x = True if len(unsharp_params) < 4 else bool(int(unsharp_params[3])) + us_ksize = int(us_ksize) + else: + us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None + gradual_latent = GradualLatent( + gl_ratio, + gl_timesteps, + gl_every_n_steps, + gl_ratio_step, + gl_s_noise, + us_ksize, + us_sigma, + us_strength, + us_target_x, + ) + pipe.set_gradual_latent(gradual_latent) + + # prepare seed + if seeds is not None: # given in prompt + # num_images_per_promptが多い場合は足りなくなるので、足りない分は前のを使う + if len(seeds) > 0: + seed = seeds.pop(0) + else: + if args.iter_same_seed: + seed = iter_seed + else: + seed = None # 前のを消す + + if seed is None: + seed = seed_random.randint(0, 2**32 - 1) + if args.interactive: + logger.info(f"seed: {seed}") + + # prepare init image, guide image and mask + init_image = mask_image = guide_image = None + + # 同一イメージを使うとき、本当はlatentに変換しておくと無駄がないが面倒なのでとりあえず毎回処理する + if init_images is not None: + init_image = init_images[global_step % len(init_images)] + + # img2imgの場合は、基本的に元画像のサイズで生成する。highres fixの場合はargs.W, args.Hとscaleに従いリサイズ済みなので無視する + # 32単位に丸めたやつにresizeされるので踏襲する + if not highres_fix: + width, height = init_image.size + width = width - width % 32 + height = height - height % 32 + if width != init_image.size[0] or height != init_image.size[1]: + logger.warning( + f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます" + ) + + if mask_images is not None: + mask_image = mask_images[global_step % len(mask_images)] + + if guide_images is not None: + if control_nets or control_net_lllites: # 複数件の場合あり + c = max(len(control_nets), len(control_net_lllites)) + p = global_step % (len(guide_images) // c) + guide_image = guide_images[p * c : p * c + c] + else: + guide_image = guide_images[global_step % len(guide_images)] + + if regional_network: + num_sub_prompts = len(prompt.split(" AND ")) + assert ( + len(networks) <= num_sub_prompts + ), "Number of networks must be less than or equal to number of sub prompts." + else: + num_sub_prompts = None + + b1 = BatchData( + False, + BatchDataBase( + global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt + ), + BatchDataExt( + width, + height, + original_width, + original_height, + original_width_negative, + original_height_negative, + crop_left, + crop_top, + steps, + scale, + negative_scale, + strength, + tuple(network_muls) if network_muls else None, + num_sub_prompts, + ), + ) + if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要? + process_batch(batch_data, highres_fix) + batch_data.clear() + + batch_data.append(b1) + if len(batch_data) == args.batch_size: + prev_image = process_batch(batch_data, highres_fix)[0] + batch_data.clear() + + global_step += 1 + + prompt_index += 1 + + if len(batch_data) > 0: + process_batch(batch_data, highres_fix) + batch_data.clear() + + logger.info("done!") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + add_logging_arguments(parser) + + parser.add_argument( + "--sdxl", action="store_true", help="load Stable Diffusion XL model / Stable Diffusion XLのモデルを読み込む" + ) + parser.add_argument( + "--v1", action="store_true", help="load Stable Diffusion v1.x model / Stable Diffusion 1.xのモデルを読み込む" + ) + parser.add_argument( + "--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む" + ) + parser.add_argument( + "--v_parameterization", action="store_true", help="enable v-parameterization training / v-parameterization学習を有効にする" + ) + + parser.add_argument("--prompt", type=str, default=None, help="prompt / プロンプト") + parser.add_argument( + "--from_file", + type=str, + default=None, + help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む", + ) + parser.add_argument( + "--from_module", + type=str, + default=None, + help="if specified, load prompts from this module / 指定時はプロンプトをモジュールから読み込む", + ) + parser.add_argument( + "--prompter_module_args", type=str, default=None, help="args for prompter module / prompterモジュールの引数" + ) + parser.add_argument( + "--interactive", + action="store_true", + help="interactive mode (generates one image) / 対話モード(生成される画像は1枚になります)", + ) + parser.add_argument( + "--no_preview", action="store_true", help="do not show generated image in interactive mode / 対話モードで画像を表示しない" + ) + parser.add_argument( + "--image_path", type=str, default=None, help="image to inpaint or to generate from / img2imgまたはinpaintを行う元画像" + ) + parser.add_argument("--mask_path", type=str, default=None, help="mask in inpainting / inpaint時のマスク") + parser.add_argument("--strength", type=float, default=None, help="img2img strength / img2img時のstrength") + parser.add_argument("--images_per_prompt", type=int, default=1, help="number of images per prompt / プロンプトあたりの出力枚数") + parser.add_argument("--outdir", type=str, default="outputs", help="dir to write results to / 生成画像の出力先") + parser.add_argument( + "--sequential_file_name", action="store_true", help="sequential output file name / 生成画像のファイル名を連番にする" + ) + parser.add_argument( + "--use_original_file_name", + action="store_true", + help="prepend original file name in img2img / img2imgで元画像のファイル名を生成画像のファイル名の先頭に付ける", + ) + # parser.add_argument("--ddim_eta", type=float, default=0.0, help="ddim eta (eta=0.0 corresponds to deterministic sampling", ) + parser.add_argument("--n_iter", type=int, default=1, help="sample this often / 繰り返し回数") + parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ") + parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅") + parser.add_argument( + "--original_height", + type=int, + default=None, + help="original height for SDXL conditioning / SDXLの条件付けに用いるoriginal heightの値", + ) + parser.add_argument( + "--original_width", + type=int, + default=None, + help="original width for SDXL conditioning / SDXLの条件付けに用いるoriginal widthの値", + ) + parser.add_argument( + "--original_height_negative", + type=int, + default=None, + help="original height for SDXL unconditioning / SDXLのネガティブ条件付けに用いるoriginal heightの値", + ) + parser.add_argument( + "--original_width_negative", + type=int, + default=None, + help="original width for SDXL unconditioning / SDXLのネガティブ条件付けに用いるoriginal widthの値", + ) + parser.add_argument( + "--crop_top", type=int, default=None, help="crop top for SDXL conditioning / SDXLの条件付けに用いるcrop topの値" + ) + parser.add_argument( + "--crop_left", type=int, default=None, help="crop left for SDXL conditioning / SDXLの条件付けに用いるcrop leftの値" + ) + parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ") + parser.add_argument( + "--vae_batch_size", + type=float, + default=None, + help="batch size for VAE, < 1.0 for ratio / VAE処理時のバッチサイズ、1未満の値の場合は通常バッチサイズの比率", + ) + parser.add_argument( + "--vae_slices", + type=int, + default=None, + help="number of slices to split image into for VAE to reduce VRAM usage, None for no splitting (default), slower if specified. 16 or 32 recommended / VAE処理時にVRAM使用量削減のため画像を分割するスライス数、Noneの場合は分割しない(デフォルト)、指定すると遅くなる。16か32程度を推奨", + ) + parser.add_argument( + "--no_half_vae", action="store_true", help="do not use fp16/bf16 precision for VAE / VAE処理時にfp16/bf16を使わない" + ) + parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数") + parser.add_argument( + "--sampler", + type=str, + default="ddim", + choices=[ + "ddim", + "pndm", + "lms", + "euler", + "euler_a", + "heun", + "dpm_2", + "dpm_2_a", + "dpmsolver", + "dpmsolver++", + "dpmsingle", + "k_lms", + "k_euler", + "k_euler_a", + "k_dpm_2", + "k_dpm_2_a", + ], + help=f"sampler (scheduler) type / サンプラー(スケジューラ)の種類", + ) + parser.add_argument( + "--scale", + type=float, + default=7.5, + help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty)) / guidance scale", + ) + parser.add_argument( + "--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ" + ) + parser.add_argument( + "--vae", + type=str, + default=None, + help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ", + ) + parser.add_argument( + "--tokenizer_cache_dir", + type=str, + default=None, + help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)", + ) + # parser.add_argument("--replace_clip_l14_336", action='store_true', + # help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える") + parser.add_argument( + "--seed", + type=int, + default=None, + help="seed, or seed of seeds in multiple generation / 1枚生成時のseed、または複数枚生成時の乱数seedを決めるためのseed", + ) + parser.add_argument( + "--iter_same_seed", + action="store_true", + help="use same seed for all prompts in iteration if no seed specified / 乱数seedの指定がないとき繰り返し内はすべて同じseedを使う(プロンプト間の差異の比較用)", + ) + parser.add_argument( + "--shuffle_prompts", + action="store_true", + help="shuffle prompts in iteration / 繰り返し内のプロンプトをシャッフルする", + ) + parser.add_argument("--fp16", action="store_true", help="use fp16 / fp16を指定し省メモリ化する") + parser.add_argument("--bf16", action="store_true", help="use bfloat16 / bfloat16を指定し省メモリ化する") + parser.add_argument("--xformers", action="store_true", help="use xformers / xformersを使用し高速化する") + parser.add_argument("--sdpa", action="store_true", help="use sdpa in PyTorch 2 / sdpa") + parser.add_argument( + "--diffusers_xformers", + action="store_true", + help="use xformers by diffusers (Hypernetworks doesn't work) / Diffusersでxformersを使用する(Hypernetwork利用不可)", + ) + parser.add_argument( + "--opt_channels_last", + action="store_true", + help="set channels last option to model / モデルにchannels lastを指定し最適化する", + ) + parser.add_argument( + "--network_module", + type=str, + default=None, + nargs="*", + help="additional network module to use / 追加ネットワークを使う時そのモジュール名", + ) + parser.add_argument( + "--network_weights", type=str, default=None, nargs="*", help="additional network weights to load / 追加ネットワークの重み" + ) + parser.add_argument( + "--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率" + ) + parser.add_argument( + "--network_args", + type=str, + default=None, + nargs="*", + help="additional arguments for network (key=value) / ネットワークへの追加の引数", + ) + parser.add_argument( + "--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する" + ) + parser.add_argument( + "--network_merge_n_models", + type=int, + default=None, + help="merge this number of networks / この数だけネットワークをマージする", + ) + parser.add_argument( + "--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする" + ) + parser.add_argument( + "--network_pre_calc", + action="store_true", + help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する", + ) + parser.add_argument( + "--network_regional_mask_max_color_codes", + type=int, + default=None, + help="max color codes for regional mask (default is None, mask by channel) / regional maskの最大色数(デフォルトはNoneでチャンネルごとのマスク)", + ) + parser.add_argument( + "--textual_inversion_embeddings", + type=str, + default=None, + nargs="*", + help="Embeddings files of Textual Inversion / Textual Inversionのembeddings", + ) + parser.add_argument( + "--clip_skip", + type=int, + default=None, + help="layer number from bottom to use in CLIP, default is 1 for SD1/2, 2 for SDXL " + + "/ CLIPの後ろからn層目の出力を使う(デフォルトはSD1/2の場合1、SDXLの場合2)", + ) + parser.add_argument( + "--max_embeddings_multiples", + type=int, + default=None, + help="max embedding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる", + ) + parser.add_argument( + "--emb_normalize_mode", + type=str, + default="original", + choices=["original", "none", "abs"], + help="embedding normalization mode / embeddingの正規化モード", + ) + parser.add_argument( + "--guide_image_path", type=str, default=None, nargs="*", help="image to ControlNet / ControlNetでガイドに使う画像" + ) + parser.add_argument( + "--highres_fix_scale", + type=float, + default=None, + help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする", + ) + parser.add_argument( + "--highres_fix_steps", + type=int, + default=28, + help="1st stage steps for highres fix / highres fixの最初のステージのステップ数", + ) + parser.add_argument( + "--highres_fix_strength", + type=float, + default=None, + help="1st stage img2img strength for highres fix / highres fixの最初のステージのimg2img時のstrength、省略時はstrengthと同じ", + ) + parser.add_argument( + "--highres_fix_save_1st", + action="store_true", + help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する", + ) + parser.add_argument( + "--highres_fix_latents_upscaling", + action="store_true", + help="use latents upscaling for highres fix / highres fixでlatentで拡大する", + ) + parser.add_argument( + "--highres_fix_upscaler", + type=str, + default=None, + help="upscaler module for highres fix / highres fixで使うupscalerのモジュール名", + ) + parser.add_argument( + "--highres_fix_upscaler_args", + type=str, + default=None, + help="additional arguments for upscaler (key=value) / upscalerへの追加の引数", + ) + parser.add_argument( + "--highres_fix_disable_control_net", + action="store_true", + help="disable ControlNet for highres fix / highres fixでControlNetを使わない", + ) + + parser.add_argument( + "--negative_scale", + type=float, + default=None, + help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する", + ) + + parser.add_argument( + "--control_net_lllite_models", + type=str, + default=None, + nargs="*", + help="ControlNet models to use / 使用するControlNetのモデル名", + ) + parser.add_argument( + "--control_net_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名" + ) + parser.add_argument( + "--control_net_preps", + type=str, + default=None, + nargs="*", + help="ControlNet preprocess to use / 使用するControlNetのプリプロセス名", + ) + parser.add_argument( + "--control_net_multipliers", type=float, default=None, nargs="*", help="ControlNet multiplier / ControlNetの適用率" + ) + parser.add_argument( + "--control_net_ratios", + type=float, + default=None, + nargs="*", + help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率", + ) + parser.add_argument( + "--clip_vision_strength", + type=float, + default=None, + help="enable CLIP Vision Conditioning for img2img with this strength / img2imgでCLIP Vision Conditioningを有効にしてこのstrengthで処理する", + ) + + # Deep Shrink + parser.add_argument( + "--ds_depth_1", + type=int, + default=None, + help="Enable Deep Shrink with this depth 1, valid values are 0 to 8 / Deep Shrinkをこのdepthで有効にする", + ) + parser.add_argument( + "--ds_timesteps_1", + type=int, + default=650, + help="Apply Deep Shrink depth 1 until this timesteps / Deep Shrink depth 1を適用するtimesteps", + ) + parser.add_argument("--ds_depth_2", type=int, default=None, help="Deep Shrink depth 2 / Deep Shrinkのdepth 2") + parser.add_argument( + "--ds_timesteps_2", + type=int, + default=650, + help="Apply Deep Shrink depth 2 until this timesteps / Deep Shrink depth 2を適用するtimesteps", + ) + parser.add_argument( + "--ds_ratio", type=float, default=0.5, help="Deep Shrink ratio for downsampling / Deep Shrinkのdownsampling比率" + ) + + # gradual latent + parser.add_argument( + "--gradual_latent_timesteps", + type=int, + default=None, + help="enable Gradual Latent hires fix and apply upscaling from this timesteps / Gradual Latent hires fixをこのtimestepsで有効にし、このtimestepsからアップスケーリングを適用する", + ) + parser.add_argument( + "--gradual_latent_ratio", + type=float, + default=0.5, + help=" this size ratio, 0.5 means 1/2 / Gradual Latent hires fixをこのサイズ比率で有効にする、0.5は1/2を意味する", + ) + parser.add_argument( + "--gradual_latent_ratio_step", + type=float, + default=0.125, + help="step to increase ratio for Gradual Latent / Gradual Latentのratioをどのくらいずつ上げるか", + ) + parser.add_argument( + "--gradual_latent_every_n_steps", + type=int, + default=3, + help="steps to increase size of latents every this steps for Gradual Latent / Gradual Latentでlatentsのサイズをこのステップごとに上げる", + ) + parser.add_argument( + "--gradual_latent_s_noise", + type=float, + default=1.0, + help="s_noise for Gradual Latent / Gradual Latentのs_noise", + ) + parser.add_argument( + "--gradual_latent_unsharp_params", + type=str, + default=None, + help="unsharp mask parameters for Gradual Latent: ksize, sigma, strength, target-x (1 means True). `3,0.5,0.5,1` or `3,1.0,1.0,0` is recommended /" + + " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength, target-x. `3,0.5,0.5,1` または `3,1.0,1.0,0` が推奨", + ) + + # # parser.add_argument( + # "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像" + # ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + main(args) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index be43847a6..2c40f1a06 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -64,17 +64,11 @@ import diffusers import numpy as np -import torch - -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +import torch +from library.device_utils import init_ipex, clean_memory, get_preferred_device +init_ipex() - ipex_init() -except Exception: - pass import torchvision from diffusers import ( AutoencoderKL, @@ -107,8 +101,15 @@ from tools.original_control_net import ControlNetInfo from library.original_unet import UNet2DConditionModel, InferUNet2DConditionModel from library.original_unet import FlashAttentionFunction +from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) # scheduler: SCHEDULER_LINEAR_START = 0.00085 @@ -144,12 +145,12 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa): if mem_eff_attn: - print("Enable memory efficient attention for U-Net") + logger.info("Enable memory efficient attention for U-Net") # これはDiffusersのU-Netではなく自前のU-Netなので置き換えなくても良い unet.set_use_memory_efficient_attention(False, True) elif xformers: - print("Enable xformers for U-Net") + logger.info("Enable xformers for U-Net") try: import xformers.ops except ImportError: @@ -157,7 +158,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio unet.set_use_memory_efficient_attention(True, False) elif sdpa: - print("Enable SDPA for U-Net") + logger.info("Enable SDPA for U-Net") unet.set_use_memory_efficient_attention(False, False) unet.set_use_sdpa(True) @@ -173,7 +174,7 @@ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xform def replace_vae_attn_to_memory_efficient(): - print("VAE Attention.forward has been replaced to FlashAttention (not xformers)") + logger.info("VAE Attention.forward has been replaced to FlashAttention (not xformers)") flash_func = FlashAttentionFunction def forward_flash_attn(self, hidden_states, **kwargs): @@ -229,7 +230,7 @@ def forward_flash_attn_0_14(self, hidden_states, **kwargs): def replace_vae_attn_to_xformers(): - print("VAE: Attention.forward has been replaced to xformers") + logger.info("VAE: Attention.forward has been replaced to xformers") import xformers.ops def forward_xformers(self, hidden_states, **kwargs): @@ -285,7 +286,7 @@ def forward_xformers_0_14(self, hidden_states, **kwargs): def replace_vae_attn_to_sdpa(): - print("VAE: Attention.forward has been replaced to sdpa") + logger.info("VAE: Attention.forward has been replaced to sdpa") def forward_sdpa(self, hidden_states, **kwargs): residual = hidden_states @@ -454,6 +455,8 @@ def __init__( self.control_nets: List[ControlNetInfo] = [] self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない + self.gradual_latent: GradualLatent = None + # Textual Inversion def add_token_replacement(self, target_token_id, rep_token_ids): self.token_replacements[target_token_id] = rep_token_ids @@ -484,6 +487,14 @@ def add_token_replacement_XTI(self, target_token_id, rep_token_ids): def set_control_nets(self, ctrl_nets): self.control_nets = ctrl_nets + def set_gradual_latent(self, gradual_latent): + if gradual_latent is None: + logger.info("gradual_latent is disabled") + self.gradual_latent = None + else: + logger.info(f"gradual_latent is enabled: {gradual_latent}") + self.gradual_latent = gradual_latent # (ds_ratio, start_timesteps, every_n_steps, ratio_step) + # region xformersとか使う部分:独自に書き換えるので関係なし def enable_xformers_memory_efficient_attention(self): @@ -689,7 +700,7 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 if not do_classifier_free_guidance and negative_scale is not None: - print(f"negative_scale is ignored if guidance scalle <= 1.0") + logger.warning(f"negative_scale is ignored if guidance scalle <= 1.0") negative_scale = None # get unconditional embeddings for classifier free guidance @@ -771,11 +782,11 @@ def __call__( clip_text_input = prompt_tokens if clip_text_input.shape[1] > self.tokenizer.model_max_length: # TODO 75文字を超えたら警告を出す? - print("trim text input", clip_text_input.shape) + logger.info(f"trim text input {clip_text_input.shape}") clip_text_input = torch.cat( [clip_text_input[:, : self.tokenizer.model_max_length - 1], clip_text_input[:, -1].unsqueeze(1)], dim=1 ) - print("trimmed", clip_text_input.shape) + logger.info(f"trimmed {clip_text_input.shape}") for i, clip_prompt in enumerate(clip_prompts): if clip_prompt is not None: # clip_promptがあれば上書きする @@ -893,8 +904,7 @@ def __call__( init_latent_dist = self.vae.encode(init_image).latent_dist init_latents = init_latent_dist.sample(generator=generator) else: - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory() init_latents = [] for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)): init_latent_dist = self.vae.encode( @@ -958,7 +968,49 @@ def __call__( else: text_emb_last = text_embeddings + enable_gradual_latent = False + if self.gradual_latent: + if not hasattr(self.scheduler, "set_gradual_latent_params"): + logger.info("gradual_latent is not supported for this scheduler. Ignoring.") + logger.info(f'{self.scheduler.__class__.__name__}') + else: + enable_gradual_latent = True + step_elapsed = 1000 + current_ratio = self.gradual_latent.ratio + + # first, we downscale the latents to the specified ratio / 最初に指定された比率にlatentsをダウンスケールする + height, width = latents.shape[-2:] + org_dtype = latents.dtype + if org_dtype == torch.bfloat16: + latents = latents.float() + latents = torch.nn.functional.interpolate( + latents, scale_factor=current_ratio, mode="bicubic", align_corners=False + ).to(org_dtype) + + # apply unsharp mask / アンシャープマスクを適用する + if self.gradual_latent.gaussian_blur_ksize: + latents = self.gradual_latent.apply_unshark_mask(latents) + for i, t in enumerate(tqdm(timesteps)): + resized_size = None + if enable_gradual_latent: + # gradually upscale the latents / latentsを徐々にアップスケールする + if ( + t < self.gradual_latent.start_timesteps + and current_ratio < 1.0 + and step_elapsed >= self.gradual_latent.every_n_steps + ): + current_ratio = min(current_ratio + self.gradual_latent.ratio_step, 1.0) + # make divisible by 8 because size of latents must be divisible at bottom of UNet + h = int(height * current_ratio) // 8 * 8 + w = int(width * current_ratio) // 8 * 8 + resized_size = (h, w) + self.scheduler.set_gradual_latent_params(resized_size, self.gradual_latent) + step_elapsed = 0 + else: + self.scheduler.set_gradual_latent_params(None, None) + step_elapsed += 1 + # expand the latents if we are doing classifier free guidance latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -1052,8 +1104,7 @@ def __call__( if vae_batch_size >= batch_size: image = self.vae.decode(latents).sample else: - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory() images = [] for i in tqdm(range(0, batch_size, vae_batch_size)): images.append( @@ -1540,7 +1591,9 @@ def cond_fn_vgg16_b1(self, latents, timestep, index, text_embeddings, noise_pred image_embeddings = self.vgg16_feat_model(image)["feat"] # バッチサイズが複数だと正しく動くかわからない - loss = ((image_embeddings - guide_embeddings) ** 2).mean() * guidance_scale # MSE style transferでコンテンツの損失はMSEなので + loss = ( + (image_embeddings - guide_embeddings) ** 2 + ).mean() * guidance_scale # MSE style transferでコンテンツの損失はMSEなので grads = -torch.autograd.grad(loss, latents)[0] if isinstance(self.scheduler, LMSDiscreteScheduler): @@ -1704,7 +1757,7 @@ def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length: if word.strip() == "BREAK": # pad until next multiple of tokenizer's max token length pad_len = pipe.tokenizer.model_max_length - (len(text_token) % pipe.tokenizer.model_max_length) - print(f"BREAK pad_len: {pad_len}") + logger.info(f"BREAK pad_len: {pad_len}") for i in range(pad_len): # v2のときEOSをつけるべきかどうかわからないぜ # if i == 0: @@ -1734,7 +1787,7 @@ def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length: tokens.append(text_token) weights.append(text_weight) if truncated: - print("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") + logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") return tokens, weights @@ -2046,7 +2099,7 @@ def handle_dynamic_prompt_variants(prompt, repeat_count): elif len(count_range) == 2: count_range = [int(count_range[0]), int(count_range[1])] else: - print(f"invalid count range: {count_range}") + logger.warning(f"invalid count range: {count_range}") count_range = [1, 1] if count_range[0] > count_range[1]: count_range = [count_range[1], count_range[0]] @@ -2116,7 +2169,7 @@ def replacer(): # def load_clip_l14_336(dtype): -# print(f"loading CLIP: {CLIP_ID_L14_336}") +# logger.info(f"loading CLIP: {CLIP_ID_L14_336}") # text_encoder = CLIPTextModel.from_pretrained(CLIP_ID_L14_336, torch_dtype=dtype) # return text_encoder @@ -2131,6 +2184,7 @@ class BatchDataBase(NamedTuple): mask_image: Any clip_prompt: str guide_image: Any + raw_prompt: str class BatchDataExt(NamedTuple): @@ -2163,9 +2217,9 @@ def main(args): # assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません" if args.v_parameterization and not args.v2: - print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") + logger.warning("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") if args.v2 and args.clip_skip is not None: - print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") + logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") # モデルを読み込む if not os.path.isfile(args.ckpt): # ファイルがないならパターンで探し、一つだけ該当すればそれを使う @@ -2175,10 +2229,10 @@ def main(args): use_stable_diffusion_format = os.path.isfile(args.ckpt) if use_stable_diffusion_format: - print("load StableDiffusion checkpoint") + logger.info("load StableDiffusion checkpoint") text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt) else: - print("load Diffusers pretrained models") + logger.info("load Diffusers pretrained models") loading_pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype) text_encoder = loading_pipe.text_encoder vae = loading_pipe.vae @@ -2201,21 +2255,21 @@ def main(args): # VAEを読み込む if args.vae is not None: vae = model_util.load_vae(args.vae, dtype) - print("additional VAE loaded") + logger.info("additional VAE loaded") # # 置換するCLIPを読み込む # if args.replace_clip_l14_336: # text_encoder = load_clip_l14_336(dtype) - # print(f"large clip {CLIP_ID_L14_336} is loaded") + # logger.info(f"large clip {CLIP_ID_L14_336} is loaded") if args.clip_guidance_scale > 0.0 or args.clip_image_guidance_scale: - print("prepare clip model") + logger.info("prepare clip model") clip_model = CLIPModel.from_pretrained(CLIP_MODEL_PATH, torch_dtype=dtype) else: clip_model = None if args.vgg16_guidance_scale > 0.0: - print("prepare resnet model") + logger.info("prepare resnet model") vgg16_model = torchvision.models.vgg16(torchvision.models.VGG16_Weights.IMAGENET1K_V1) else: vgg16_model = None @@ -2227,7 +2281,7 @@ def main(args): replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa) # tokenizerを読み込む - print("loading tokenizer") + logger.info("loading tokenizer") if use_stable_diffusion_format: tokenizer = train_util.load_tokenizer(args) @@ -2250,7 +2304,7 @@ def main(args): scheduler_cls = EulerDiscreteScheduler scheduler_module = diffusers.schedulers.scheduling_euler_discrete elif args.sampler == "euler_a" or args.sampler == "k_euler_a": - scheduler_cls = EulerAncestralDiscreteScheduler + scheduler_cls = EulerAncestralDiscreteSchedulerGL scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete elif args.sampler == "dpmsolver" or args.sampler == "dpmsolver++": scheduler_cls = DPMSolverMultistepScheduler @@ -2286,7 +2340,7 @@ def reset_sampler_noises(self, noises): self.sampler_noises = noises def randn(self, shape, device=None, dtype=None, layout=None, generator=None): - # print("replacing", shape, len(self.sampler_noises), self.sampler_noise_index) + # logger.info(f"replacing {shape} {len(self.sampler_noises)} {self.sampler_noise_index}") if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises): noise = self.sampler_noises[self.sampler_noise_index] if shape != noise.shape: @@ -2295,7 +2349,7 @@ def randn(self, shape, device=None, dtype=None, layout=None, generator=None): noise = None if noise == None: - print(f"unexpected noise request: {self.sampler_noise_index}, {shape}") + logger.warning(f"unexpected noise request: {self.sampler_noise_index}, {shape}") noise = torch.randn(shape, dtype=dtype, device=device, generator=generator) self.sampler_noise_index += 1 @@ -2326,11 +2380,11 @@ def __getattr__(self, item): # clip_sample=Trueにする if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: - print("set clip_sample to True") + logger.info("set clip_sample to True") scheduler.config.clip_sample = True # deviceを決定する - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # "mps"を考量してない + device = get_preferred_device() # custom pipelineをコピったやつを生成する if args.vae_slices: @@ -2383,7 +2437,7 @@ def __getattr__(self, item): network_merge = 0 for i, network_module in enumerate(args.network_module): - print("import network module:", network_module) + logger.info(f"import network module: {network_module}") imported_module = importlib.import_module(network_module) network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] @@ -2401,7 +2455,7 @@ def __getattr__(self, item): raise ValueError("No weight. Weight is required.") network_weight = args.network_weights[i] - print("load network weights from:", network_weight) + logger.info(f"load network weights from: {network_weight}") if model_util.is_safetensors(network_weight) and args.network_show_meta: from safetensors.torch import safe_open @@ -2409,7 +2463,7 @@ def __getattr__(self, item): with safe_open(network_weight, framework="pt") as f: metadata = f.metadata() if metadata is not None: - print(f"metadata for: {network_weight}: {metadata}") + logger.info(f"metadata for: {network_weight}: {metadata}") network, weights_sd = imported_module.create_network_from_weights( network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs @@ -2419,20 +2473,20 @@ def __getattr__(self, item): mergeable = network.is_mergeable() if network_merge and not mergeable: - print("network is not mergiable. ignore merge option.") + logger.warning("network is not mergiable. ignore merge option.") if not mergeable or i >= network_merge: # not merging network.apply_to(text_encoder, unet) info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい - print(f"weights are loaded: {info}") + logger.info(f"weights are loaded: {info}") if args.opt_channels_last: network.to(memory_format=torch.channels_last) network.to(dtype).to(device) if network_pre_calc: - print("backup original weights") + logger.info("backup original weights") network.backup_weights() networks.append(network) @@ -2446,7 +2500,7 @@ def __getattr__(self, item): # upscalerの指定があれば取得する upscaler = None if args.highres_fix_upscaler: - print("import upscaler module:", args.highres_fix_upscaler) + logger.info(f"import upscaler module {args.highres_fix_upscaler}") imported_module = importlib.import_module(args.highres_fix_upscaler) us_kwargs = {} @@ -2455,7 +2509,7 @@ def __getattr__(self, item): key, value = net_arg.split("=") us_kwargs[key] = value - print("create upscaler") + logger.info("create upscaler") upscaler = imported_module.create_upscaler(**us_kwargs) upscaler.to(dtype).to(device) @@ -2472,7 +2526,7 @@ def __getattr__(self, item): control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) if args.opt_channels_last: - print(f"set optimizing: channels last") + logger.info(f"set optimizing: channels last") text_encoder.to(memory_format=torch.channels_last) vae.to(memory_format=torch.channels_last) unet.to(memory_format=torch.channels_last) @@ -2504,7 +2558,7 @@ def __getattr__(self, item): args.vgg16_guidance_layer, ) pipe.set_control_nets(control_nets) - print("pipeline is ready.") + logger.info("pipeline is ready.") if args.diffusers_xformers: pipe.enable_xformers_memory_efficient_attention() @@ -2513,6 +2567,29 @@ def __getattr__(self, item): if args.ds_depth_1 is not None: unet.set_deep_shrink(args.ds_depth_1, args.ds_timesteps_1, args.ds_depth_2, args.ds_timesteps_2, args.ds_ratio) + # Gradual Latent + if args.gradual_latent_timesteps is not None: + if args.gradual_latent_unsharp_params: + us_params = args.gradual_latent_unsharp_params.split(",") + us_ksize, us_sigma, us_strength = [float(v) for v in us_params[:3]] + us_target_x = True if len(us_params) <= 3 else bool(int(us_params[3])) + us_ksize = int(us_ksize) + else: + us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None + + gradual_latent = GradualLatent( + args.gradual_latent_ratio, + args.gradual_latent_timesteps, + args.gradual_latent_every_n_steps, + args.gradual_latent_ratio_step, + args.gradual_latent_s_noise, + us_ksize, + us_sigma, + us_strength, + us_target_x, + ) + pipe.set_gradual_latent(gradual_latent) + # Extended Textual Inversion および Textual Inversionを処理する if args.XTI_embeddings: diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI @@ -2534,7 +2611,9 @@ def __getattr__(self, item): embeds = next(iter(data.values())) if type(embeds) != torch.Tensor: - raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {embeds_file}") + raise ValueError( + f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {embeds_file}" + ) num_vectors_per_token = embeds.size()[0] token_string = os.path.splitext(os.path.basename(embeds_file))[0] @@ -2547,7 +2626,7 @@ def __getattr__(self, item): ), f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}" token_ids = tokenizer.convert_tokens_to_ids(token_strings) - print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids}") + logger.info(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids}") assert ( min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1 ), f"token ids is not ordered" @@ -2606,7 +2685,7 @@ def __getattr__(self, item): ), f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}" token_ids = tokenizer.convert_tokens_to_ids(token_strings) - print(f"XTI embeddings `{token_string}` loaded. Tokens are added: {token_ids}") + logger.info(f"XTI embeddings `{token_string}` loaded. Tokens are added: {token_ids}") # if num_vectors_per_token > 1: pipe.add_token_replacement(token_ids[0], token_ids) @@ -2631,10 +2710,10 @@ def __getattr__(self, item): # promptを取得する if args.from_file is not None: - print(f"reading prompts from {args.from_file}") + logger.info(f"reading prompts from {args.from_file}") with open(args.from_file, "r", encoding="utf-8") as f: prompt_list = f.read().splitlines() - prompt_list = [d for d in prompt_list if len(d.strip()) > 0] + prompt_list = [d for d in prompt_list if len(d.strip()) > 0 and d[0] != "#"] elif args.prompt is not None: prompt_list = [args.prompt] else: @@ -2660,7 +2739,7 @@ def load_images(path): for p in paths: image = Image.open(p) if image.mode != "RGB": - print(f"convert image to RGB from {image.mode}: {p}") + logger.info(f"convert image to RGB from {image.mode}: {p}") image = image.convert("RGB") images.append(image) @@ -2676,24 +2755,24 @@ def resize_images(imgs, size): return resized if args.image_path is not None: - print(f"load image for img2img: {args.image_path}") + logger.info(f"load image for img2img: {args.image_path}") init_images = load_images(args.image_path) assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}" - print(f"loaded {len(init_images)} images for img2img") + logger.info(f"loaded {len(init_images)} images for img2img") else: init_images = None if args.mask_path is not None: - print(f"load mask for inpainting: {args.mask_path}") + logger.info(f"load mask for inpainting: {args.mask_path}") mask_images = load_images(args.mask_path) assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.image_path}" - print(f"loaded {len(mask_images)} mask images for inpainting") + logger.info(f"loaded {len(mask_images)} mask images for inpainting") else: mask_images = None # promptがないとき、画像のPngInfoから取得する if init_images is not None and len(prompt_list) == 0 and not args.interactive: - print("get prompts from images' meta data") + logger.info("get prompts from images' meta data") for img in init_images: if "prompt" in img.text: prompt = img.text["prompt"] @@ -2722,17 +2801,17 @@ def resize_images(imgs, size): h = int(h * args.highres_fix_scale + 0.5) if init_images is not None: - print(f"resize img2img source images to {w}*{h}") + logger.info(f"resize img2img source images to {w}*{h}") init_images = resize_images(init_images, (w, h)) if mask_images is not None: - print(f"resize img2img mask images to {w}*{h}") + logger.info(f"resize img2img mask images to {w}*{h}") mask_images = resize_images(mask_images, (w, h)) regional_network = False if networks and mask_images: # mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応 regional_network = True - print("use mask as region") + logger.info("use mask as region") size = None for i, network in enumerate(networks): @@ -2757,14 +2836,16 @@ def resize_images(imgs, size): prev_image = None # for VGG16 guided if args.guide_image_path is not None: - print(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}") + logger.info(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}") guide_images = [] for p in args.guide_image_path: guide_images.extend(load_images(p)) - print(f"loaded {len(guide_images)} guide images for guidance") + logger.info(f"loaded {len(guide_images)} guide images for guidance") if len(guide_images) == 0: - print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}") + logger.info( + f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}" + ) guide_images = None else: guide_images = None @@ -2790,7 +2871,7 @@ def resize_images(imgs, size): max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples for gen_iter in range(args.n_iter): - print(f"iteration {gen_iter+1}/{args.n_iter}") + logger.info(f"iteration {gen_iter+1}/{args.n_iter}") iter_seed = random.randint(0, 0x7FFFFFFF) # shuffle prompt list @@ -2806,7 +2887,7 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling - print("process 1st stage") + logger.info("process 1st stage") batch_1st = [] for _, base, ext in batch: width_1st = int(ext.width * args.highres_fix_scale + 0.5) @@ -2832,7 +2913,7 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): images_1st = process_batch(batch_1st, True, True) # 2nd stageのバッチを作成して以下処理する - print("process 2nd stage") + logger.info("process 2nd stage") width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height if upscaler: @@ -2878,13 +2959,14 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): # このバッチの情報を取り出す ( return_latents, - (step_first, _, _, _, init_image, mask_image, _, guide_image), + (step_first, _, _, _, init_image, mask_image, _, guide_image, _), (width, height, steps, scale, negative_scale, strength, network_muls, num_sub_prompts), ) = batch[0] noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR) prompts = [] negative_prompts = [] + raw_prompts = [] start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) noises = [ torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) @@ -2915,11 +2997,16 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): all_images_are_same = True all_masks_are_same = True all_guide_images_are_same = True - for i, (_, (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch): + for i, ( + _, + (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt), + _, + ) in enumerate(batch): prompts.append(prompt) negative_prompts.append(negative_prompt) seeds.append(seed) clip_prompts.append(clip_prompt) + raw_prompts.append(raw_prompt) if init_image is not None: init_images.append(init_image) @@ -2983,7 +3070,7 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): n.restore_weights() for n in networks: n.pre_calculation() - print("pre-calculation... done") + logger.info("pre-calculation... done") images = pipe( prompts, @@ -3011,8 +3098,8 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): # save image highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) - for i, (image, prompt, negative_prompts, seed, clip_prompt) in enumerate( - zip(images, prompts, negative_prompts, seeds, clip_prompts) + for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate( + zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts) ): if highres_fix: seed -= 1 # record original seed @@ -3028,6 +3115,8 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): metadata.add_text("negative-scale", str(negative_scale)) if clip_prompt is not None: metadata.add_text("clip-prompt", clip_prompt) + if raw_prompt is not None: + metadata.add_text("raw-prompt", raw_prompt) if args.use_original_file_name and init_images is not None: if type(init_images) is list: @@ -3050,7 +3139,9 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): cv2.waitKey() cv2.destroyAllWindows() except ImportError: - print("opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません") + logger.info( + "opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません" + ) return images @@ -3063,7 +3154,8 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): # interactive valid = False while not valid: - print("\nType prompt:") + logger.info("") + logger.info("Type prompt:") try: raw_prompt = input() except EOFError: @@ -3104,40 +3196,48 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): ds_timesteps_2 = args.ds_timesteps_2 ds_ratio = args.ds_ratio + # Gradual Latent + gl_timesteps = None # means no override + gl_ratio = args.gradual_latent_ratio + gl_every_n_steps = args.gradual_latent_every_n_steps + gl_ratio_step = args.gradual_latent_ratio_step + gl_s_noise = args.gradual_latent_s_noise + gl_unsharp_params = args.gradual_latent_unsharp_params + prompt_args = raw_prompt.strip().split(" --") prompt = prompt_args[0] - print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") + logger.info(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") for parg in prompt_args[1:]: try: m = re.match(r"w (\d+)", parg, re.IGNORECASE) if m: width = int(m.group(1)) - print(f"width: {width}") + logger.info(f"width: {width}") continue m = re.match(r"h (\d+)", parg, re.IGNORECASE) if m: height = int(m.group(1)) - print(f"height: {height}") + logger.info(f"height: {height}") continue m = re.match(r"s (\d+)", parg, re.IGNORECASE) if m: # steps steps = max(1, min(1000, int(m.group(1)))) - print(f"steps: {steps}") + logger.info(f"steps: {steps}") continue m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) if m: # seed seeds = [int(d) for d in m.group(1).split(",")] - print(f"seeds: {seeds}") + logger.info(f"seeds: {seeds}") continue m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) if m: # scale scale = float(m.group(1)) - print(f"scale: {scale}") + logger.info(f"scale: {scale}") continue m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) @@ -3146,25 +3246,25 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): negative_scale = None else: negative_scale = float(m.group(1)) - print(f"negative scale: {negative_scale}") + logger.info(f"negative scale: {negative_scale}") continue m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) if m: # strength strength = float(m.group(1)) - print(f"strength: {strength}") + logger.info(f"strength: {strength}") continue m = re.match(r"n (.+)", parg, re.IGNORECASE) if m: # negative prompt negative_prompt = m.group(1) - print(f"negative prompt: {negative_prompt}") + logger.info(f"negative prompt: {negative_prompt}") continue m = re.match(r"c (.+)", parg, re.IGNORECASE) if m: # clip prompt clip_prompt = m.group(1) - print(f"clip prompt: {clip_prompt}") + logger.info(f"clip prompt: {clip_prompt}") continue m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) @@ -3172,47 +3272,89 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): network_muls = [float(v) for v in m.group(1).split(",")] while len(network_muls) < len(networks): network_muls.append(network_muls[-1]) - print(f"network mul: {network_muls}") + logger.info(f"network mul: {network_muls}") continue # Deep Shrink m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink depth 1 ds_depth_1 = int(m.group(1)) - print(f"deep shrink depth 1: {ds_depth_1}") + logger.info(f"deep shrink depth 1: {ds_depth_1}") continue m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink timesteps 1 ds_timesteps_1 = int(m.group(1)) ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - print(f"deep shrink timesteps 1: {ds_timesteps_1}") + logger.info(f"deep shrink timesteps 1: {ds_timesteps_1}") continue m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink depth 2 ds_depth_2 = int(m.group(1)) ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - print(f"deep shrink depth 2: {ds_depth_2}") + logger.info(f"deep shrink depth 2: {ds_depth_2}") continue m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink timesteps 2 ds_timesteps_2 = int(m.group(1)) ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - print(f"deep shrink timesteps 2: {ds_timesteps_2}") + logger.info(f"deep shrink timesteps 2: {ds_timesteps_2}") continue m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink ratio ds_ratio = float(m.group(1)) ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - print(f"deep shrink ratio: {ds_ratio}") + logger.info(f"deep shrink ratio: {ds_ratio}") + continue + + # Gradual Latent + m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent timesteps + gl_timesteps = int(m.group(1)) + logger.info(f"gradual latent timesteps: {gl_timesteps}") + continue + + m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio + gl_ratio = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent ratio: {ds_ratio}") + continue + + m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent every n steps + gl_every_n_steps = int(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent every n steps: {gl_every_n_steps}") + continue + + m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio step + gl_ratio_step = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent ratio step: {gl_ratio_step}") + continue + + m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent s noise + gl_s_noise = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent s noise: {gl_s_noise}") + continue + + m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # gradual latent unsharp params + gl_unsharp_params = m.group(1) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") continue except ValueError as ex: - print(f"Exception in parsing / 解析エラー: {parg}") - print(ex) + logger.info(f"Exception in parsing / 解析エラー: {parg}") + logger.info(ex) # override Deep Shrink if ds_depth_1 is not None: @@ -3220,6 +3362,31 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): ds_depth_1 = args.ds_depth_1 or 3 unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio) + # override Gradual Latent + if gl_timesteps is not None: + if gl_timesteps < 0: + gl_timesteps = args.gradual_latent_timesteps or 650 + if gl_unsharp_params is not None: + unsharp_params = gl_unsharp_params.split(",") + us_ksize, us_sigma, us_strength = [float(v) for v in unsharp_params[:3]] + logger.info(f'{unsharp_params}') + us_target_x = True if len(unsharp_params) < 4 else bool(int(unsharp_params[3])) + us_ksize = int(us_ksize) + else: + us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None + gradual_latent = GradualLatent( + gl_ratio, + gl_timesteps, + gl_every_n_steps, + gl_ratio_step, + gl_s_noise, + us_ksize, + us_sigma, + us_strength, + us_target_x, + ) + pipe.set_gradual_latent(gradual_latent) + # prepare seed if seeds is not None: # given in prompt # 数が足りないなら前のをそのまま使う @@ -3230,7 +3397,7 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): if len(predefined_seeds) > 0: seed = predefined_seeds.pop(0) else: - print("predefined seeds are exhausted") + logger.info("predefined seeds are exhausted") seed = None elif args.iter_same_seed: seed = iter_seed @@ -3240,7 +3407,7 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): if seed is None: seed = random.randint(0, 0x7FFFFFFF) if args.interactive: - print(f"seed: {seed}") + logger.info(f"seed: {seed}") # prepare init image, guide image and mask init_image = mask_image = guide_image = None @@ -3256,7 +3423,7 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): width = width - width % 32 height = height - height % 32 if width != init_image.size[0] or height != init_image.size[1]: - print( + logger.info( f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます" ) @@ -3272,9 +3439,9 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): guide_image = guide_images[global_step % len(guide_images)] elif args.clip_image_guidance_scale > 0 or args.vgg16_guidance_scale > 0: if prev_image is None: - print("Generate 1st image without guide image.") + logger.info("Generate 1st image without guide image.") else: - print("Use previous image as guide image.") + logger.info("Use previous image as guide image.") guide_image = prev_image if regional_network: @@ -3287,7 +3454,9 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): b1 = BatchData( False, - BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), + BatchDataBase( + global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt + ), BatchDataExt( width, height, @@ -3316,22 +3485,31 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): process_batch(batch_data, highres_fix) batch_data.clear() - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() - parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む") + add_logging_arguments(parser) + + parser.add_argument( + "--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む" + ) parser.add_argument( "--v_parameterization", action="store_true", help="enable v-parameterization training / v-parameterization学習を有効にする" ) parser.add_argument("--prompt", type=str, default=None, help="prompt / プロンプト") parser.add_argument( - "--from_file", type=str, default=None, help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む" + "--from_file", + type=str, + default=None, + help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む", ) parser.add_argument( - "--interactive", action="store_true", help="interactive mode (generates one image) / 対話モード(生成される画像は1枚になります)" + "--interactive", + action="store_true", + help="interactive mode (generates one image) / 対話モード(生成される画像は1枚になります)", ) parser.add_argument( "--no_preview", action="store_true", help="do not show generated image in interactive mode / 対話モードで画像を表示しない" @@ -3343,7 +3521,9 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("--strength", type=float, default=None, help="img2img strength / img2img時のstrength") parser.add_argument("--images_per_prompt", type=int, default=1, help="number of images per prompt / プロンプトあたりの出力枚数") parser.add_argument("--outdir", type=str, default="outputs", help="dir to write results to / 生成画像の出力先") - parser.add_argument("--sequential_file_name", action="store_true", help="sequential output file name / 生成画像のファイル名を連番にする") + parser.add_argument( + "--sequential_file_name", action="store_true", help="sequential output file name / 生成画像のファイル名を連番にする" + ) parser.add_argument( "--use_original_file_name", action="store_true", @@ -3397,9 +3577,14 @@ def setup_parser() -> argparse.ArgumentParser: default=7.5, help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty)) / guidance scale", ) - parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ") parser.add_argument( - "--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ" + "--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ" + ) + parser.add_argument( + "--vae", + type=str, + default=None, + help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ", ) parser.add_argument( "--tokenizer_cache_dir", @@ -3435,25 +3620,46 @@ def setup_parser() -> argparse.ArgumentParser: help="use xformers by diffusers (Hypernetworks doesn't work) / Diffusersでxformersを使用する(Hypernetwork利用不可)", ) parser.add_argument( - "--opt_channels_last", action="store_true", help="set channels last option to model / モデルにchannels lastを指定し最適化する" + "--opt_channels_last", + action="store_true", + help="set channels last option to model / モデルにchannels lastを指定し最適化する", ) parser.add_argument( - "--network_module", type=str, default=None, nargs="*", help="additional network module to use / 追加ネットワークを使う時そのモジュール名" + "--network_module", + type=str, + default=None, + nargs="*", + help="additional network module to use / 追加ネットワークを使う時そのモジュール名", ) parser.add_argument( "--network_weights", type=str, default=None, nargs="*", help="additional network weights to load / 追加ネットワークの重み" ) - parser.add_argument("--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率") parser.add_argument( - "--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数" + "--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率" + ) + parser.add_argument( + "--network_args", + type=str, + default=None, + nargs="*", + help="additional arguments for network (key=value) / ネットワークへの追加の引数", + ) + parser.add_argument( + "--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する" + ) + parser.add_argument( + "--network_merge_n_models", + type=int, + default=None, + help="merge this number of networks / この数だけネットワークをマージする", ) - parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する") parser.add_argument( - "--network_merge_n_models", type=int, default=None, help="merge this number of networks / この数だけネットワークをマージする" + "--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする" ) - parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする") parser.add_argument( - "--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する" + "--network_pre_calc", + action="store_true", + help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する", ) parser.add_argument( "--network_regional_mask_max_color_codes", @@ -3475,7 +3681,9 @@ def setup_parser() -> argparse.ArgumentParser: nargs="*", help="Embeddings files of Extended Textual Inversion / Extended Textual Inversionのembeddings", ) - parser.add_argument("--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う") + parser.add_argument( + "--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う" + ) parser.add_argument( "--max_embeddings_multiples", type=int, @@ -3516,7 +3724,10 @@ def setup_parser() -> argparse.ArgumentParser: help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする", ) parser.add_argument( - "--highres_fix_steps", type=int, default=28, help="1st stage steps for highres fix / highres fixの最初のステージのステップ数" + "--highres_fix_steps", + type=int, + default=28, + help="1st stage steps for highres fix / highres fixの最初のステージのステップ数", ) parser.add_argument( "--highres_fix_strength", @@ -3525,7 +3736,9 @@ def setup_parser() -> argparse.ArgumentParser: help="1st stage img2img strength for highres fix / highres fixの最初のステージのimg2img時のstrength、省略時はstrengthと同じ", ) parser.add_argument( - "--highres_fix_save_1st", action="store_true", help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する" + "--highres_fix_save_1st", + action="store_true", + help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する", ) parser.add_argument( "--highres_fix_latents_upscaling", @@ -3533,7 +3746,10 @@ def setup_parser() -> argparse.ArgumentParser: help="use latents upscaling for highres fix / highres fixでlatentで拡大する", ) parser.add_argument( - "--highres_fix_upscaler", type=str, default=None, help="upscaler module for highres fix / highres fixで使うupscalerのモジュール名" + "--highres_fix_upscaler", + type=str, + default=None, + help="upscaler module for highres fix / highres fixで使うupscalerのモジュール名", ) parser.add_argument( "--highres_fix_upscaler_args", @@ -3548,14 +3764,21 @@ def setup_parser() -> argparse.ArgumentParser: ) parser.add_argument( - "--negative_scale", type=float, default=None, help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する" + "--negative_scale", + type=float, + default=None, + help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する", ) parser.add_argument( "--control_net_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名" ) parser.add_argument( - "--control_net_preps", type=str, default=None, nargs="*", help="ControlNet preprocess to use / 使用するControlNetのプリプロセス名" + "--control_net_preps", + type=str, + default=None, + nargs="*", + help="ControlNet preprocess to use / 使用するControlNetのプリプロセス名", ) parser.add_argument("--control_net_weights", type=float, default=None, nargs="*", help="ControlNet weights / ControlNetの重み") parser.add_argument( @@ -3593,6 +3816,45 @@ def setup_parser() -> argparse.ArgumentParser: "--ds_ratio", type=float, default=0.5, help="Deep Shrink ratio for downsampling / Deep Shrinkのdownsampling比率" ) + # gradual latent + parser.add_argument( + "--gradual_latent_timesteps", + type=int, + default=None, + help="enable Gradual Latent hires fix and apply upscaling from this timesteps / Gradual Latent hires fixをこのtimestepsで有効にし、このtimestepsからアップスケーリングを適用する", + ) + parser.add_argument( + "--gradual_latent_ratio", + type=float, + default=0.5, + help=" this size ratio, 0.5 means 1/2 / Gradual Latent hires fixをこのサイズ比率で有効にする、0.5は1/2を意味する", + ) + parser.add_argument( + "--gradual_latent_ratio_step", + type=float, + default=0.125, + help="step to increase ratio for Gradual Latent / Gradual Latentのratioをどのくらいずつ上げるか", + ) + parser.add_argument( + "--gradual_latent_every_n_steps", + type=int, + default=3, + help="steps to increase size of latents every this steps for Gradual Latent / Gradual Latentでlatentsのサイズをこのステップごとに上げる", + ) + parser.add_argument( + "--gradual_latent_s_noise", + type=float, + default=1.0, + help="s_noise for Gradual Latent / Gradual Latentのs_noise", + ) + parser.add_argument( + "--gradual_latent_unsharp_params", + type=str, + default=None, + help="unsharp mask parameters for Gradual Latent: ksize, sigma, strength, target-x (1 means True). `3,0.5,0.5,1` or `3,1.0,1.0,0` is recommended /" + + " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength, target-x. `3,0.5,0.5,1` または `3,1.0,1.0,0` が推奨", + ) + return parser @@ -3600,4 +3862,5 @@ def setup_parser() -> argparse.ArgumentParser: parser = setup_parser() args = parser.parse_args() + setup_logging(args, reset=True) main(args) diff --git a/library/config_util.py b/library/config_util.py index 47868f3ba..ff4de0921 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -1,462 +1,513 @@ import argparse from dataclasses import ( - asdict, - dataclass, + asdict, + dataclass, ) import functools import random from textwrap import dedent, indent import json from pathlib import Path + # from toolz import curry from typing import ( - List, - Optional, - Sequence, - Tuple, - Union, + List, + Optional, + Sequence, + Tuple, + Union, ) import toml import voluptuous from voluptuous import ( - Any, - ExactSequence, - MultipleInvalid, - Object, - Required, - Schema, + Any, + ExactSequence, + MultipleInvalid, + Object, + Required, + Schema, ) from transformers import CLIPTokenizer from . import train_util from .train_util import ( - DreamBoothSubset, - FineTuningSubset, - ControlNetSubset, - DreamBoothDataset, - FineTuningDataset, - ControlNetDataset, - DatasetGroup, + DreamBoothSubset, + FineTuningSubset, + ControlNetSubset, + DreamBoothDataset, + FineTuningDataset, + ControlNetDataset, + DatasetGroup, ) +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) def add_config_arguments(parser: argparse.ArgumentParser): - parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル") + parser.add_argument( + "--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル" + ) + # TODO: inherit Params class in Subset, Dataset + @dataclass class BaseSubsetParams: - image_dir: Optional[str] = None - num_repeats: int = 1 - shuffle_caption: bool = False - caption_separator: str = ',', - keep_tokens: int = 0 - keep_tokens_separator: str = None, - color_aug: bool = False - flip_aug: bool = False - face_crop_aug_range: Optional[Tuple[float, float]] = None - random_crop: bool = False - caption_prefix: Optional[str] = None - caption_suffix: Optional[str] = None - caption_dropout_rate: float = 0.0 - caption_dropout_every_n_epochs: int = 0 - caption_tag_dropout_rate: float = 0.0 - token_warmup_min: int = 1 - token_warmup_step: float = 0 + image_dir: Optional[str] = None + num_repeats: int = 1 + shuffle_caption: bool = False + caption_separator: str = (",",) + keep_tokens: int = 0 + keep_tokens_separator: str = (None,) + secondary_separator: Optional[str] = None + enable_wildcard: bool = False + color_aug: bool = False + flip_aug: bool = False + face_crop_aug_range: Optional[Tuple[float, float]] = None + random_crop: bool = False + caption_prefix: Optional[str] = None + caption_suffix: Optional[str] = None + caption_dropout_rate: float = 0.0 + caption_dropout_every_n_epochs: int = 0 + caption_tag_dropout_rate: float = 0.0 + token_warmup_min: int = 1 + token_warmup_step: float = 0 + @dataclass class DreamBoothSubsetParams(BaseSubsetParams): - is_reg: bool = False - class_tokens: Optional[str] = None - caption_extension: str = ".caption" + is_reg: bool = False + class_tokens: Optional[str] = None + caption_extension: str = ".caption" + @dataclass class FineTuningSubsetParams(BaseSubsetParams): - metadata_file: Optional[str] = None + metadata_file: Optional[str] = None + @dataclass class ControlNetSubsetParams(BaseSubsetParams): - conditioning_data_dir: str = None - caption_extension: str = ".caption" + conditioning_data_dir: str = None + caption_extension: str = ".caption" + @dataclass class BaseDatasetParams: - tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None - max_token_length: int = None - resolution: Optional[Tuple[int, int]] = None - debug_dataset: bool = False + tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None + max_token_length: int = None + resolution: Optional[Tuple[int, int]] = None + network_multiplier: float = 1.0 + debug_dataset: bool = False + @dataclass class DreamBoothDatasetParams(BaseDatasetParams): - batch_size: int = 1 - enable_bucket: bool = False - min_bucket_reso: int = 256 - max_bucket_reso: int = 1024 - bucket_reso_steps: int = 64 - bucket_no_upscale: bool = False - prior_loss_weight: float = 1.0 + batch_size: int = 1 + enable_bucket: bool = False + min_bucket_reso: int = 256 + max_bucket_reso: int = 1024 + bucket_reso_steps: int = 64 + bucket_no_upscale: bool = False + prior_loss_weight: float = 1.0 + @dataclass class FineTuningDatasetParams(BaseDatasetParams): - batch_size: int = 1 - enable_bucket: bool = False - min_bucket_reso: int = 256 - max_bucket_reso: int = 1024 - bucket_reso_steps: int = 64 - bucket_no_upscale: bool = False + batch_size: int = 1 + enable_bucket: bool = False + min_bucket_reso: int = 256 + max_bucket_reso: int = 1024 + bucket_reso_steps: int = 64 + bucket_no_upscale: bool = False + @dataclass class ControlNetDatasetParams(BaseDatasetParams): - batch_size: int = 1 - enable_bucket: bool = False - min_bucket_reso: int = 256 - max_bucket_reso: int = 1024 - bucket_reso_steps: int = 64 - bucket_no_upscale: bool = False + batch_size: int = 1 + enable_bucket: bool = False + min_bucket_reso: int = 256 + max_bucket_reso: int = 1024 + bucket_reso_steps: int = 64 + bucket_no_upscale: bool = False + @dataclass class SubsetBlueprint: - params: Union[DreamBoothSubsetParams, FineTuningSubsetParams] + params: Union[DreamBoothSubsetParams, FineTuningSubsetParams] + @dataclass class DatasetBlueprint: - is_dreambooth: bool - is_controlnet: bool - params: Union[DreamBoothDatasetParams, FineTuningDatasetParams] - subsets: Sequence[SubsetBlueprint] + is_dreambooth: bool + is_controlnet: bool + params: Union[DreamBoothDatasetParams, FineTuningDatasetParams] + subsets: Sequence[SubsetBlueprint] + @dataclass class DatasetGroupBlueprint: - datasets: Sequence[DatasetBlueprint] + datasets: Sequence[DatasetBlueprint] + + @dataclass class Blueprint: - dataset_group: DatasetGroupBlueprint + dataset_group: DatasetGroupBlueprint class ConfigSanitizer: - # @curry - @staticmethod - def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple: - Schema(ExactSequence([klass, klass]))(value) - return tuple(value) - - # @curry - @staticmethod - def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple: - Schema(Any(klass, ExactSequence([klass, klass])))(value) - try: - Schema(klass)(value) - return (value, value) - except: - return ConfigSanitizer.__validate_and_convert_twodim(klass, value) - - # subset schema - SUBSET_ASCENDABLE_SCHEMA = { - "color_aug": bool, - "face_crop_aug_range": functools.partial(__validate_and_convert_twodim.__func__, float), - "flip_aug": bool, - "num_repeats": int, - "random_crop": bool, - "shuffle_caption": bool, - "keep_tokens": int, - "keep_tokens_separator": str, - "token_warmup_min": int, - "token_warmup_step": Any(float,int), - "caption_prefix": str, - "caption_suffix": str, - } - # DO means DropOut - DO_SUBSET_ASCENDABLE_SCHEMA = { - "caption_dropout_every_n_epochs": int, - "caption_dropout_rate": Any(float, int), - "caption_tag_dropout_rate": Any(float, int), - } - # DB means DreamBooth - DB_SUBSET_ASCENDABLE_SCHEMA = { - "caption_extension": str, - "class_tokens": str, - } - DB_SUBSET_DISTINCT_SCHEMA = { - Required("image_dir"): str, - "is_reg": bool, - } - # FT means FineTuning - FT_SUBSET_DISTINCT_SCHEMA = { - Required("metadata_file"): str, - "image_dir": str, - } - CN_SUBSET_ASCENDABLE_SCHEMA = { - "caption_extension": str, - } - CN_SUBSET_DISTINCT_SCHEMA = { - Required("image_dir"): str, - Required("conditioning_data_dir"): str, - } - - # datasets schema - DATASET_ASCENDABLE_SCHEMA = { - "batch_size": int, - "bucket_no_upscale": bool, - "bucket_reso_steps": int, - "enable_bucket": bool, - "max_bucket_reso": int, - "min_bucket_reso": int, - "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), - } - - # options handled by argparse but not handled by user config - ARGPARSE_SPECIFIC_SCHEMA = { - "debug_dataset": bool, - "max_token_length": Any(None, int), - "prior_loss_weight": Any(float, int), - } - # for handling default None value of argparse - ARGPARSE_NULLABLE_OPTNAMES = [ - "face_crop_aug_range", - "resolution", - ] - # prepare map because option name may differ among argparse and user config - ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = { - "train_batch_size": "batch_size", - "dataset_repeats": "num_repeats", - } - - def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None: - assert support_dreambooth or support_finetuning or support_controlnet, "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。" - - self.db_subset_schema = self.__merge_dict( - self.SUBSET_ASCENDABLE_SCHEMA, - self.DB_SUBSET_DISTINCT_SCHEMA, - self.DB_SUBSET_ASCENDABLE_SCHEMA, - self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, - ) - - self.ft_subset_schema = self.__merge_dict( - self.SUBSET_ASCENDABLE_SCHEMA, - self.FT_SUBSET_DISTINCT_SCHEMA, - self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, - ) - - self.cn_subset_schema = self.__merge_dict( - self.SUBSET_ASCENDABLE_SCHEMA, - self.CN_SUBSET_DISTINCT_SCHEMA, - self.CN_SUBSET_ASCENDABLE_SCHEMA, - self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, - ) - - self.db_dataset_schema = self.__merge_dict( - self.DATASET_ASCENDABLE_SCHEMA, - self.SUBSET_ASCENDABLE_SCHEMA, - self.DB_SUBSET_ASCENDABLE_SCHEMA, - self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, - {"subsets": [self.db_subset_schema]}, - ) - - self.ft_dataset_schema = self.__merge_dict( - self.DATASET_ASCENDABLE_SCHEMA, - self.SUBSET_ASCENDABLE_SCHEMA, - self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, - {"subsets": [self.ft_subset_schema]}, - ) - - self.cn_dataset_schema = self.__merge_dict( - self.DATASET_ASCENDABLE_SCHEMA, - self.SUBSET_ASCENDABLE_SCHEMA, - self.CN_SUBSET_ASCENDABLE_SCHEMA, - self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, - {"subsets": [self.cn_subset_schema]}, - ) - - if support_dreambooth and support_finetuning: - def validate_flex_dataset(dataset_config: dict): - subsets_config = dataset_config.get("subsets", []) - - if support_controlnet and all(["conditioning_data_dir" in subset for subset in subsets_config]): - return Schema(self.cn_dataset_schema)(dataset_config) - # check dataset meets FT style - # NOTE: all FT subsets should have "metadata_file" - elif all(["metadata_file" in subset for subset in subsets_config]): - return Schema(self.ft_dataset_schema)(dataset_config) - # check dataset meets DB style - # NOTE: all DB subsets should have no "metadata_file" - elif all(["metadata_file" not in subset for subset in subsets_config]): - return Schema(self.db_dataset_schema)(dataset_config) - else: - raise voluptuous.Invalid("DreamBooth subset and fine tuning subset cannot be mixed in the same dataset. Please split them into separate datasets. / DreamBoothのサブセットとfine tuninのサブセットを同一のデータセットに混在させることはできません。別々のデータセットに分割してください。") - - self.dataset_schema = validate_flex_dataset - elif support_dreambooth: - self.dataset_schema = self.db_dataset_schema - elif support_finetuning: - self.dataset_schema = self.ft_dataset_schema - elif support_controlnet: - self.dataset_schema = self.cn_dataset_schema - - self.general_schema = self.__merge_dict( - self.DATASET_ASCENDABLE_SCHEMA, - self.SUBSET_ASCENDABLE_SCHEMA, - self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {}, - self.CN_SUBSET_ASCENDABLE_SCHEMA if support_controlnet else {}, - self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, - ) - - self.user_config_validator = Schema({ - "general": self.general_schema, - "datasets": [self.dataset_schema], - }) - - self.argparse_schema = self.__merge_dict( - self.general_schema, - self.ARGPARSE_SPECIFIC_SCHEMA, - {optname: Any(None, self.general_schema[optname]) for optname in self.ARGPARSE_NULLABLE_OPTNAMES}, - {a_name: self.general_schema[c_name] for a_name, c_name in self.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME.items()}, - ) - - self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA) - - def sanitize_user_config(self, user_config: dict) -> dict: - try: - return self.user_config_validator(user_config) - except MultipleInvalid: - # TODO: エラー発生時のメッセージをわかりやすくする - print("Invalid user config / ユーザ設定の形式が正しくないようです") - raise - - # NOTE: In nature, argument parser result is not needed to be sanitize - # However this will help us to detect program bug - def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace: - try: - return self.argparse_config_validator(argparse_namespace) - except MultipleInvalid: - # XXX: this should be a bug - print("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。") - raise - - # NOTE: value would be overwritten by latter dict if there is already the same key - @staticmethod - def __merge_dict(*dict_list: dict) -> dict: - merged = {} - for schema in dict_list: - # merged |= schema - for k, v in schema.items(): - merged[k] = v - return merged + # @curry + @staticmethod + def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple: + Schema(ExactSequence([klass, klass]))(value) + return tuple(value) + + # @curry + @staticmethod + def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple: + Schema(Any(klass, ExactSequence([klass, klass])))(value) + try: + Schema(klass)(value) + return (value, value) + except: + return ConfigSanitizer.__validate_and_convert_twodim(klass, value) + + # subset schema + SUBSET_ASCENDABLE_SCHEMA = { + "color_aug": bool, + "face_crop_aug_range": functools.partial(__validate_and_convert_twodim.__func__, float), + "flip_aug": bool, + "num_repeats": int, + "random_crop": bool, + "shuffle_caption": bool, + "keep_tokens": int, + "keep_tokens_separator": str, + "secondary_separator": str, + "enable_wildcard": bool, + "token_warmup_min": int, + "token_warmup_step": Any(float, int), + "caption_prefix": str, + "caption_suffix": str, + } + # DO means DropOut + DO_SUBSET_ASCENDABLE_SCHEMA = { + "caption_dropout_every_n_epochs": int, + "caption_dropout_rate": Any(float, int), + "caption_tag_dropout_rate": Any(float, int), + } + # DB means DreamBooth + DB_SUBSET_ASCENDABLE_SCHEMA = { + "caption_extension": str, + "class_tokens": str, + } + DB_SUBSET_DISTINCT_SCHEMA = { + Required("image_dir"): str, + "is_reg": bool, + } + # FT means FineTuning + FT_SUBSET_DISTINCT_SCHEMA = { + Required("metadata_file"): str, + "image_dir": str, + } + CN_SUBSET_ASCENDABLE_SCHEMA = { + "caption_extension": str, + } + CN_SUBSET_DISTINCT_SCHEMA = { + Required("image_dir"): str, + Required("conditioning_data_dir"): str, + } + + # datasets schema + DATASET_ASCENDABLE_SCHEMA = { + "batch_size": int, + "bucket_no_upscale": bool, + "bucket_reso_steps": int, + "enable_bucket": bool, + "max_bucket_reso": int, + "min_bucket_reso": int, + "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), + "network_multiplier": float, + } + + # options handled by argparse but not handled by user config + ARGPARSE_SPECIFIC_SCHEMA = { + "debug_dataset": bool, + "max_token_length": Any(None, int), + "prior_loss_weight": Any(float, int), + } + # for handling default None value of argparse + ARGPARSE_NULLABLE_OPTNAMES = [ + "face_crop_aug_range", + "resolution", + ] + # prepare map because option name may differ among argparse and user config + ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = { + "train_batch_size": "batch_size", + "dataset_repeats": "num_repeats", + } + + def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None: + assert ( + support_dreambooth or support_finetuning or support_controlnet + ), "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。" + + self.db_subset_schema = self.__merge_dict( + self.SUBSET_ASCENDABLE_SCHEMA, + self.DB_SUBSET_DISTINCT_SCHEMA, + self.DB_SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + ) + + self.ft_subset_schema = self.__merge_dict( + self.SUBSET_ASCENDABLE_SCHEMA, + self.FT_SUBSET_DISTINCT_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + ) + + self.cn_subset_schema = self.__merge_dict( + self.SUBSET_ASCENDABLE_SCHEMA, + self.CN_SUBSET_DISTINCT_SCHEMA, + self.CN_SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + ) + + self.db_dataset_schema = self.__merge_dict( + self.DATASET_ASCENDABLE_SCHEMA, + self.SUBSET_ASCENDABLE_SCHEMA, + self.DB_SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + {"subsets": [self.db_subset_schema]}, + ) + + self.ft_dataset_schema = self.__merge_dict( + self.DATASET_ASCENDABLE_SCHEMA, + self.SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + {"subsets": [self.ft_subset_schema]}, + ) + + self.cn_dataset_schema = self.__merge_dict( + self.DATASET_ASCENDABLE_SCHEMA, + self.SUBSET_ASCENDABLE_SCHEMA, + self.CN_SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + {"subsets": [self.cn_subset_schema]}, + ) + + if support_dreambooth and support_finetuning: + + def validate_flex_dataset(dataset_config: dict): + subsets_config = dataset_config.get("subsets", []) + + if support_controlnet and all(["conditioning_data_dir" in subset for subset in subsets_config]): + return Schema(self.cn_dataset_schema)(dataset_config) + # check dataset meets FT style + # NOTE: all FT subsets should have "metadata_file" + elif all(["metadata_file" in subset for subset in subsets_config]): + return Schema(self.ft_dataset_schema)(dataset_config) + # check dataset meets DB style + # NOTE: all DB subsets should have no "metadata_file" + elif all(["metadata_file" not in subset for subset in subsets_config]): + return Schema(self.db_dataset_schema)(dataset_config) + else: + raise voluptuous.Invalid( + "DreamBooth subset and fine tuning subset cannot be mixed in the same dataset. Please split them into separate datasets. / DreamBoothのサブセットとfine tuninのサブセットを同一のデータセットに混在させることはできません。別々のデータセットに分割してください。" + ) + + self.dataset_schema = validate_flex_dataset + elif support_dreambooth: + self.dataset_schema = self.db_dataset_schema + elif support_finetuning: + self.dataset_schema = self.ft_dataset_schema + elif support_controlnet: + self.dataset_schema = self.cn_dataset_schema + + self.general_schema = self.__merge_dict( + self.DATASET_ASCENDABLE_SCHEMA, + self.SUBSET_ASCENDABLE_SCHEMA, + self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {}, + self.CN_SUBSET_ASCENDABLE_SCHEMA if support_controlnet else {}, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + ) + + self.user_config_validator = Schema( + { + "general": self.general_schema, + "datasets": [self.dataset_schema], + } + ) + + self.argparse_schema = self.__merge_dict( + self.general_schema, + self.ARGPARSE_SPECIFIC_SCHEMA, + {optname: Any(None, self.general_schema[optname]) for optname in self.ARGPARSE_NULLABLE_OPTNAMES}, + {a_name: self.general_schema[c_name] for a_name, c_name in self.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME.items()}, + ) + + self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA) + + def sanitize_user_config(self, user_config: dict) -> dict: + try: + return self.user_config_validator(user_config) + except MultipleInvalid: + # TODO: エラー発生時のメッセージをわかりやすくする + logger.error("Invalid user config / ユーザ設定の形式が正しくないようです") + raise + + # NOTE: In nature, argument parser result is not needed to be sanitize + # However this will help us to detect program bug + def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace: + try: + return self.argparse_config_validator(argparse_namespace) + except MultipleInvalid: + # XXX: this should be a bug + logger.error( + "Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。" + ) + raise + + # NOTE: value would be overwritten by latter dict if there is already the same key + @staticmethod + def __merge_dict(*dict_list: dict) -> dict: + merged = {} + for schema in dict_list: + # merged |= schema + for k, v in schema.items(): + merged[k] = v + return merged class BlueprintGenerator: - BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = { - } - - def __init__(self, sanitizer: ConfigSanitizer): - self.sanitizer = sanitizer - - # runtime_params is for parameters which is only configurable on runtime, such as tokenizer - def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint: - sanitized_user_config = self.sanitizer.sanitize_user_config(user_config) - sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace) - - # convert argparse namespace to dict like config - # NOTE: it is ok to have extra entries in dict - optname_map = self.sanitizer.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME - argparse_config = {optname_map.get(optname, optname): value for optname, value in vars(sanitized_argparse_namespace).items()} - - general_config = sanitized_user_config.get("general", {}) - - dataset_blueprints = [] - for dataset_config in sanitized_user_config.get("datasets", []): - # NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets - subsets = dataset_config.get("subsets", []) - is_dreambooth = all(["metadata_file" not in subset for subset in subsets]) - is_controlnet = all(["conditioning_data_dir" in subset for subset in subsets]) - if is_controlnet: - subset_params_klass = ControlNetSubsetParams - dataset_params_klass = ControlNetDatasetParams - elif is_dreambooth: - subset_params_klass = DreamBoothSubsetParams - dataset_params_klass = DreamBoothDatasetParams - else: - subset_params_klass = FineTuningSubsetParams - dataset_params_klass = FineTuningDatasetParams - - subset_blueprints = [] - for subset_config in subsets: - params = self.generate_params_by_fallbacks(subset_params_klass, - [subset_config, dataset_config, general_config, argparse_config, runtime_params]) - subset_blueprints.append(SubsetBlueprint(params)) - - params = self.generate_params_by_fallbacks(dataset_params_klass, - [dataset_config, general_config, argparse_config, runtime_params]) - dataset_blueprints.append(DatasetBlueprint(is_dreambooth, is_controlnet, params, subset_blueprints)) - - dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints) - - return Blueprint(dataset_group_blueprint) - - @staticmethod - def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]): - name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME - search_value = BlueprintGenerator.search_value - default_params = asdict(param_klass()) - param_names = default_params.keys() - - params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names} - - return param_klass(**params) - - @staticmethod - def search_value(key: str, fallbacks: Sequence[dict], default_value = None): - for cand in fallbacks: - value = cand.get(key) - if value is not None: - return value - - return default_value + BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {} + + def __init__(self, sanitizer: ConfigSanitizer): + self.sanitizer = sanitizer + + # runtime_params is for parameters which is only configurable on runtime, such as tokenizer + def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint: + sanitized_user_config = self.sanitizer.sanitize_user_config(user_config) + sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace) + + # convert argparse namespace to dict like config + # NOTE: it is ok to have extra entries in dict + optname_map = self.sanitizer.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME + argparse_config = { + optname_map.get(optname, optname): value for optname, value in vars(sanitized_argparse_namespace).items() + } + + general_config = sanitized_user_config.get("general", {}) + + dataset_blueprints = [] + for dataset_config in sanitized_user_config.get("datasets", []): + # NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets + subsets = dataset_config.get("subsets", []) + is_dreambooth = all(["metadata_file" not in subset for subset in subsets]) + is_controlnet = all(["conditioning_data_dir" in subset for subset in subsets]) + if is_controlnet: + subset_params_klass = ControlNetSubsetParams + dataset_params_klass = ControlNetDatasetParams + elif is_dreambooth: + subset_params_klass = DreamBoothSubsetParams + dataset_params_klass = DreamBoothDatasetParams + else: + subset_params_klass = FineTuningSubsetParams + dataset_params_klass = FineTuningDatasetParams + + subset_blueprints = [] + for subset_config in subsets: + params = self.generate_params_by_fallbacks( + subset_params_klass, [subset_config, dataset_config, general_config, argparse_config, runtime_params] + ) + subset_blueprints.append(SubsetBlueprint(params)) + + params = self.generate_params_by_fallbacks( + dataset_params_klass, [dataset_config, general_config, argparse_config, runtime_params] + ) + dataset_blueprints.append(DatasetBlueprint(is_dreambooth, is_controlnet, params, subset_blueprints)) + + dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints) + + return Blueprint(dataset_group_blueprint) + + @staticmethod + def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]): + name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME + search_value = BlueprintGenerator.search_value + default_params = asdict(param_klass()) + param_names = default_params.keys() + + params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names} + + return param_klass(**params) + + @staticmethod + def search_value(key: str, fallbacks: Sequence[dict], default_value=None): + for cand in fallbacks: + value = cand.get(key) + if value is not None: + return value + + return default_value def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint): - datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] - - for dataset_blueprint in dataset_group_blueprint.datasets: - if dataset_blueprint.is_controlnet: - subset_klass = ControlNetSubset - dataset_klass = ControlNetDataset - elif dataset_blueprint.is_dreambooth: - subset_klass = DreamBoothSubset - dataset_klass = DreamBoothDataset - else: - subset_klass = FineTuningSubset - dataset_klass = FineTuningDataset - - subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] - dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) - datasets.append(dataset) - - # print info - info = "" - for i, dataset in enumerate(datasets): - is_dreambooth = isinstance(dataset, DreamBoothDataset) - is_controlnet = isinstance(dataset, ControlNetDataset) - info += dedent(f"""\ + datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] + + for dataset_blueprint in dataset_group_blueprint.datasets: + if dataset_blueprint.is_controlnet: + subset_klass = ControlNetSubset + dataset_klass = ControlNetDataset + elif dataset_blueprint.is_dreambooth: + subset_klass = DreamBoothSubset + dataset_klass = DreamBoothDataset + else: + subset_klass = FineTuningSubset + dataset_klass = FineTuningDataset + + subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] + dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) + datasets.append(dataset) + + # print info + info = "" + for i, dataset in enumerate(datasets): + is_dreambooth = isinstance(dataset, DreamBoothDataset) + is_controlnet = isinstance(dataset, ControlNetDataset) + info += dedent( + f"""\ [Dataset {i}] batch_size: {dataset.batch_size} resolution: {(dataset.width, dataset.height)} enable_bucket: {dataset.enable_bucket} - """) - - if dataset.enable_bucket: - info += indent(dedent(f"""\ + network_multiplier: {dataset.network_multiplier} + """ + ) + + if dataset.enable_bucket: + info += indent( + dedent( + f"""\ min_bucket_reso: {dataset.min_bucket_reso} max_bucket_reso: {dataset.max_bucket_reso} bucket_reso_steps: {dataset.bucket_reso_steps} bucket_no_upscale: {dataset.bucket_no_upscale} - \n"""), " ") - else: - info += "\n" + \n""" + ), + " ", + ) + else: + info += "\n" - for j, subset in enumerate(dataset.subsets): - info += indent(dedent(f"""\ + for j, subset in enumerate(dataset.subsets): + info += indent( + dedent( + f"""\ [Subset {j} of Dataset {i}] image_dir: "{subset.image_dir}" image_count: {subset.img_count} @@ -464,6 +515,8 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu shuffle_caption: {subset.shuffle_caption} keep_tokens: {subset.keep_tokens} keep_tokens_separator: {subset.keep_tokens_separator} + secondary_separator: {subset.secondary_separator} + enable_wildcard: {subset.enable_wildcard} caption_dropout_rate: {subset.caption_dropout_rate} caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} @@ -475,147 +528,179 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu random_crop: {subset.random_crop} token_warmup_min: {subset.token_warmup_min}, token_warmup_step: {subset.token_warmup_step}, - """), " ") - - if is_dreambooth: - info += indent(dedent(f"""\ + """ + ), + " ", + ) + + if is_dreambooth: + info += indent( + dedent( + f"""\ is_reg: {subset.is_reg} class_tokens: {subset.class_tokens} caption_extension: {subset.caption_extension} - \n"""), " ") - elif not is_controlnet: - info += indent(dedent(f"""\ + \n""" + ), + " ", + ) + elif not is_controlnet: + info += indent( + dedent( + f"""\ metadata_file: {subset.metadata_file} - \n"""), " ") + \n""" + ), + " ", + ) - print(info) + logger.info(f"{info}") - # make buckets first because it determines the length of dataset - # and set the same seed for all datasets - seed = random.randint(0, 2**31) # actual seed is seed + epoch_no - for i, dataset in enumerate(datasets): - print(f"[Dataset {i}]") - dataset.make_buckets() - dataset.set_seed(seed) + # make buckets first because it determines the length of dataset + # and set the same seed for all datasets + seed = random.randint(0, 2**31) # actual seed is seed + epoch_no + for i, dataset in enumerate(datasets): + logger.info(f"[Dataset {i}]") + dataset.make_buckets() + dataset.set_seed(seed) - return DatasetGroup(datasets) + return DatasetGroup(datasets) def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None): - def extract_dreambooth_params(name: str) -> Tuple[int, str]: - tokens = name.split('_') - try: - n_repeats = int(tokens[0]) - except ValueError as e: - print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}") - return 0, "" - caption_by_folder = '_'.join(tokens[1:]) - return n_repeats, caption_by_folder - - def generate(base_dir: Optional[str], is_reg: bool): - if base_dir is None: - return [] - - base_dir: Path = Path(base_dir) - if not base_dir.is_dir(): - return [] + def extract_dreambooth_params(name: str) -> Tuple[int, str]: + tokens = name.split("_") + try: + n_repeats = int(tokens[0]) + except ValueError as e: + logger.warning(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}") + return 0, "" + caption_by_folder = "_".join(tokens[1:]) + return n_repeats, caption_by_folder + + def generate(base_dir: Optional[str], is_reg: bool): + if base_dir is None: + return [] + + base_dir: Path = Path(base_dir) + if not base_dir.is_dir(): + return [] + + subsets_config = [] + for subdir in base_dir.iterdir(): + if not subdir.is_dir(): + continue + + num_repeats, class_tokens = extract_dreambooth_params(subdir.name) + if num_repeats < 1: + continue + + subset_config = {"image_dir": str(subdir), "num_repeats": num_repeats, "is_reg": is_reg, "class_tokens": class_tokens} + subsets_config.append(subset_config) + + return subsets_config subsets_config = [] - for subdir in base_dir.iterdir(): - if not subdir.is_dir(): - continue - - num_repeats, class_tokens = extract_dreambooth_params(subdir.name) - if num_repeats < 1: - continue - - subset_config = {"image_dir": str(subdir), "num_repeats": num_repeats, "is_reg": is_reg, "class_tokens": class_tokens} - subsets_config.append(subset_config) + subsets_config += generate(train_data_dir, False) + subsets_config += generate(reg_data_dir, True) return subsets_config - subsets_config = [] - subsets_config += generate(train_data_dir, False) - subsets_config += generate(reg_data_dir, True) - return subsets_config +def generate_controlnet_subsets_config_by_subdirs( + train_data_dir: Optional[str] = None, conditioning_data_dir: Optional[str] = None, caption_extension: str = ".txt" +): + def generate(base_dir: Optional[str]): + if base_dir is None: + return [] + base_dir: Path = Path(base_dir) + if not base_dir.is_dir(): + return [] -def generate_controlnet_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, conditioning_data_dir: Optional[str] = None, caption_extension: str = ".txt"): - def generate(base_dir: Optional[str]): - if base_dir is None: - return [] + subsets_config = [] + subset_config = { + "image_dir": train_data_dir, + "conditioning_data_dir": conditioning_data_dir, + "caption_extension": caption_extension, + "num_repeats": 1, + } + subsets_config.append(subset_config) - base_dir: Path = Path(base_dir) - if not base_dir.is_dir(): - return [] + return subsets_config subsets_config = [] - subset_config = {"image_dir": train_data_dir, "conditioning_data_dir": conditioning_data_dir, "caption_extension": caption_extension, "num_repeats": 1} - subsets_config.append(subset_config) + subsets_config += generate(train_data_dir) return subsets_config - subsets_config = [] - subsets_config += generate(train_data_dir) - return subsets_config +def load_user_config(file: str) -> dict: + file: Path = Path(file) + if not file.is_file(): + raise ValueError(f"file not found / ファイルが見つかりません: {file}") + + if file.name.lower().endswith(".json"): + try: + with open(file, "r") as f: + config = json.load(f) + except Exception: + logger.error( + f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}" + ) + raise + elif file.name.lower().endswith(".toml"): + try: + config = toml.load(file) + except Exception: + logger.error( + f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}" + ) + raise + else: + raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}") + return config -def load_user_config(file: str) -> dict: - file: Path = Path(file) - if not file.is_file(): - raise ValueError(f"file not found / ファイルが見つかりません: {file}") - - if file.name.lower().endswith('.json'): - try: - with open(file, 'r') as f: - config = json.load(f) - except Exception: - print(f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}") - raise - elif file.name.lower().endswith('.toml'): - try: - config = toml.load(file) - except Exception: - print(f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}") - raise - else: - raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}") - - return config # for config test if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--support_dreambooth", action="store_true") - parser.add_argument("--support_finetuning", action="store_true") - parser.add_argument("--support_controlnet", action="store_true") - parser.add_argument("--support_dropout", action="store_true") - parser.add_argument("dataset_config") - config_args, remain = parser.parse_known_args() - - parser = argparse.ArgumentParser() - train_util.add_dataset_arguments(parser, config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout) - train_util.add_training_arguments(parser, config_args.support_dreambooth) - argparse_namespace = parser.parse_args(remain) - train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning) + parser = argparse.ArgumentParser() + parser.add_argument("--support_dreambooth", action="store_true") + parser.add_argument("--support_finetuning", action="store_true") + parser.add_argument("--support_controlnet", action="store_true") + parser.add_argument("--support_dropout", action="store_true") + parser.add_argument("dataset_config") + config_args, remain = parser.parse_known_args() + + parser = argparse.ArgumentParser() + train_util.add_dataset_arguments( + parser, config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout + ) + train_util.add_training_arguments(parser, config_args.support_dreambooth) + argparse_namespace = parser.parse_args(remain) + train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning) - print("[argparse_namespace]") - print(vars(argparse_namespace)) + logger.info("[argparse_namespace]") + logger.info(f"{vars(argparse_namespace)}") - user_config = load_user_config(config_args.dataset_config) + user_config = load_user_config(config_args.dataset_config) - print("\n[user_config]") - print(user_config) + logger.info("") + logger.info("[user_config]") + logger.info(f"{user_config}") - sanitizer = ConfigSanitizer(config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout) - sanitized_user_config = sanitizer.sanitize_user_config(user_config) + sanitizer = ConfigSanitizer( + config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout + ) + sanitized_user_config = sanitizer.sanitize_user_config(user_config) - print("\n[sanitized_user_config]") - print(sanitized_user_config) + logger.info("") + logger.info("[sanitized_user_config]") + logger.info(f"{sanitized_user_config}") - blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace) + blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace) - print("\n[blueprint]") - print(blueprint) + logger.info("") + logger.info("[blueprint]") + logger.info(f"{blueprint}") diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index e0a026dae..a56474622 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -3,7 +3,10 @@ import random import re from typing import List, Optional, Union - +from .utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def prepare_scheduler_for_custom_training(noise_scheduler, device): if hasattr(noise_scheduler, "all_snr"): @@ -21,7 +24,7 @@ def prepare_scheduler_for_custom_training(noise_scheduler, device): def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler): # fix beta: zero terminal SNR - print(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891") + logger.info(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891") def enforce_zero_terminal_snr(betas): # Convert betas to alphas_bar_sqrt @@ -49,8 +52,8 @@ def enforce_zero_terminal_snr(betas): alphas = 1.0 - betas alphas_cumprod = torch.cumprod(alphas, dim=0) - # print("original:", noise_scheduler.betas) - # print("fixed:", betas) + # logger.info(f"original: {noise_scheduler.betas}") + # logger.info(f"fixed: {betas}") noise_scheduler.betas = betas noise_scheduler.alphas = alphas @@ -79,13 +82,13 @@ def get_snr_scale(timesteps, noise_scheduler): snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000 scale = snr_t / (snr_t + 1) # # show debug info - # print(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}") + # logger.info(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}") return scale def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss): scale = get_snr_scale(timesteps, noise_scheduler) - # print(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}") + # logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}") loss = loss + loss / scale * v_pred_like_loss return loss @@ -268,7 +271,7 @@ def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int): tokens.append(text_token) weights.append(text_weight) if truncated: - print("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") + logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") return tokens, weights diff --git a/library/device_utils.py b/library/device_utils.py new file mode 100644 index 000000000..8823c5d9a --- /dev/null +++ b/library/device_utils.py @@ -0,0 +1,84 @@ +import functools +import gc + +import torch + +try: + HAS_CUDA = torch.cuda.is_available() +except Exception: + HAS_CUDA = False + +try: + HAS_MPS = torch.backends.mps.is_available() +except Exception: + HAS_MPS = False + +try: + import intel_extension_for_pytorch as ipex # noqa + + HAS_XPU = torch.xpu.is_available() +except Exception: + HAS_XPU = False + + +def clean_memory(): + gc.collect() + if HAS_CUDA: + torch.cuda.empty_cache() + if HAS_XPU: + torch.xpu.empty_cache() + if HAS_MPS: + torch.mps.empty_cache() + + +def clean_memory_on_device(device: torch.device): + r""" + Clean memory on the specified device, will be called from training scripts. + """ + gc.collect() + + # device may "cuda" or "cuda:0", so we need to check the type of device + if device.type == "cuda": + torch.cuda.empty_cache() + if device.type == "xpu": + torch.xpu.empty_cache() + if device.type == "mps": + torch.mps.empty_cache() + + +@functools.lru_cache(maxsize=None) +def get_preferred_device() -> torch.device: + r""" + Do not call this function from training scripts. Use accelerator.device instead. + """ + if HAS_CUDA: + device = torch.device("cuda") + elif HAS_XPU: + device = torch.device("xpu") + elif HAS_MPS: + device = torch.device("mps") + else: + device = torch.device("cpu") + print(f"get_preferred_device() -> {device}") + return device + + +def init_ipex(): + """ + Apply IPEX to CUDA hijacks using `library.ipex.ipex_init`. + + This function should run right after importing torch and before doing anything else. + + If IPEX is not available, this function does nothing. + """ + try: + if HAS_XPU: + from library.ipex import ipex_init + + is_initialized, error_message = ipex_init() + if not is_initialized: + print("failed to initialize ipex:", error_message) + else: + return + except Exception as e: + print("failed to initialize ipex:", e) diff --git a/library/huggingface_util.py b/library/huggingface_util.py index 376fdb1e6..57b19d982 100644 --- a/library/huggingface_util.py +++ b/library/huggingface_util.py @@ -4,7 +4,10 @@ import argparse import os from library.utils import fire_in_thread - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None): api = HfApi( @@ -33,9 +36,9 @@ def upload( try: api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private) except Exception as e: # とりあえずRepositoryNotFoundErrorは確認したが他にあると困るので - print("===========================================") - print(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}") - print("===========================================") + logger.error("===========================================") + logger.error(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}") + logger.error("===========================================") is_folder = (type(src) == str and os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir()) @@ -56,9 +59,9 @@ def uploader(): path_in_repo=path_in_repo, ) except Exception as e: # RuntimeErrorを確認済みだが他にあると困るので - print("===========================================") - print(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}") - print("===========================================") + logger.error("===========================================") + logger.error(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}") + logger.error("===========================================") if args.async_upload and not force_sync_upload: fire_in_thread(uploader) diff --git a/library/ipex/__init__.py b/library/ipex/__init__.py index 333504935..972a3bf63 100644 --- a/library/ipex/__init__.py +++ b/library/ipex/__init__.py @@ -9,162 +9,171 @@ def ipex_init(): # pylint: disable=too-many-statements try: - # Replace cuda with xpu: - torch.cuda.current_device = torch.xpu.current_device - torch.cuda.current_stream = torch.xpu.current_stream - torch.cuda.device = torch.xpu.device - torch.cuda.device_count = torch.xpu.device_count - torch.cuda.device_of = torch.xpu.device_of - torch.cuda.get_device_name = torch.xpu.get_device_name - torch.cuda.get_device_properties = torch.xpu.get_device_properties - torch.cuda.init = torch.xpu.init - torch.cuda.is_available = torch.xpu.is_available - torch.cuda.is_initialized = torch.xpu.is_initialized - torch.cuda.is_current_stream_capturing = lambda: False - torch.cuda.set_device = torch.xpu.set_device - torch.cuda.stream = torch.xpu.stream - torch.cuda.synchronize = torch.xpu.synchronize - torch.cuda.Event = torch.xpu.Event - torch.cuda.Stream = torch.xpu.Stream - torch.cuda.FloatTensor = torch.xpu.FloatTensor - torch.Tensor.cuda = torch.Tensor.xpu - torch.Tensor.is_cuda = torch.Tensor.is_xpu - torch.UntypedStorage.cuda = torch.UntypedStorage.xpu - torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock - torch.cuda._initialized = torch.xpu.lazy_init._initialized - torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker - torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls - torch.cuda._tls = torch.xpu.lazy_init._tls - torch.cuda.threading = torch.xpu.lazy_init.threading - torch.cuda.traceback = torch.xpu.lazy_init.traceback - torch.cuda.Optional = torch.xpu.Optional - torch.cuda.__cached__ = torch.xpu.__cached__ - torch.cuda.__loader__ = torch.xpu.__loader__ - torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage - torch.cuda.Tuple = torch.xpu.Tuple - torch.cuda.streams = torch.xpu.streams - torch.cuda._lazy_new = torch.xpu._lazy_new - torch.cuda.FloatStorage = torch.xpu.FloatStorage - torch.cuda.Any = torch.xpu.Any - torch.cuda.__doc__ = torch.xpu.__doc__ - torch.cuda.default_generators = torch.xpu.default_generators - torch.cuda.HalfTensor = torch.xpu.HalfTensor - torch.cuda._get_device_index = torch.xpu._get_device_index - torch.cuda.__path__ = torch.xpu.__path__ - torch.cuda.Device = torch.xpu.Device - torch.cuda.IntTensor = torch.xpu.IntTensor - torch.cuda.ByteStorage = torch.xpu.ByteStorage - torch.cuda.set_stream = torch.xpu.set_stream - torch.cuda.BoolStorage = torch.xpu.BoolStorage - torch.cuda.os = torch.xpu.os - torch.cuda.torch = torch.xpu.torch - torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage - torch.cuda.Union = torch.xpu.Union - torch.cuda.DoubleTensor = torch.xpu.DoubleTensor - torch.cuda.ShortTensor = torch.xpu.ShortTensor - torch.cuda.LongTensor = torch.xpu.LongTensor - torch.cuda.IntStorage = torch.xpu.IntStorage - torch.cuda.LongStorage = torch.xpu.LongStorage - torch.cuda.__annotations__ = torch.xpu.__annotations__ - torch.cuda.__package__ = torch.xpu.__package__ - torch.cuda.__builtins__ = torch.xpu.__builtins__ - torch.cuda.CharTensor = torch.xpu.CharTensor - torch.cuda.List = torch.xpu.List - torch.cuda._lazy_init = torch.xpu._lazy_init - torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor - torch.cuda.DoubleStorage = torch.xpu.DoubleStorage - torch.cuda.ByteTensor = torch.xpu.ByteTensor - torch.cuda.StreamContext = torch.xpu.StreamContext - torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage - torch.cuda.ShortStorage = torch.xpu.ShortStorage - torch.cuda._lazy_call = torch.xpu._lazy_call - torch.cuda.HalfStorage = torch.xpu.HalfStorage - torch.cuda.random = torch.xpu.random - torch.cuda._device = torch.xpu._device - torch.cuda.classproperty = torch.xpu.classproperty - torch.cuda.__name__ = torch.xpu.__name__ - torch.cuda._device_t = torch.xpu._device_t - torch.cuda.warnings = torch.xpu.warnings - torch.cuda.__spec__ = torch.xpu.__spec__ - torch.cuda.BoolTensor = torch.xpu.BoolTensor - torch.cuda.CharStorage = torch.xpu.CharStorage - torch.cuda.__file__ = torch.xpu.__file__ - torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork - # torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing + if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_xpu_hijacked") and torch.cuda.is_xpu_hijacked: + return True, "Skipping IPEX hijack" + else: + # Replace cuda with xpu: + torch.cuda.current_device = torch.xpu.current_device + torch.cuda.current_stream = torch.xpu.current_stream + torch.cuda.device = torch.xpu.device + torch.cuda.device_count = torch.xpu.device_count + torch.cuda.device_of = torch.xpu.device_of + torch.cuda.get_device_name = torch.xpu.get_device_name + torch.cuda.get_device_properties = torch.xpu.get_device_properties + torch.cuda.init = torch.xpu.init + torch.cuda.is_available = torch.xpu.is_available + torch.cuda.is_initialized = torch.xpu.is_initialized + torch.cuda.is_current_stream_capturing = lambda: False + torch.cuda.set_device = torch.xpu.set_device + torch.cuda.stream = torch.xpu.stream + torch.cuda.synchronize = torch.xpu.synchronize + torch.cuda.Event = torch.xpu.Event + torch.cuda.Stream = torch.xpu.Stream + torch.cuda.FloatTensor = torch.xpu.FloatTensor + torch.Tensor.cuda = torch.Tensor.xpu + torch.Tensor.is_cuda = torch.Tensor.is_xpu + torch.UntypedStorage.cuda = torch.UntypedStorage.xpu + torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock + torch.cuda._initialized = torch.xpu.lazy_init._initialized + torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker + torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls + torch.cuda._tls = torch.xpu.lazy_init._tls + torch.cuda.threading = torch.xpu.lazy_init.threading + torch.cuda.traceback = torch.xpu.lazy_init.traceback + torch.cuda.Optional = torch.xpu.Optional + torch.cuda.__cached__ = torch.xpu.__cached__ + torch.cuda.__loader__ = torch.xpu.__loader__ + torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage + torch.cuda.Tuple = torch.xpu.Tuple + torch.cuda.streams = torch.xpu.streams + torch.cuda._lazy_new = torch.xpu._lazy_new + torch.cuda.FloatStorage = torch.xpu.FloatStorage + torch.cuda.Any = torch.xpu.Any + torch.cuda.__doc__ = torch.xpu.__doc__ + torch.cuda.default_generators = torch.xpu.default_generators + torch.cuda.HalfTensor = torch.xpu.HalfTensor + torch.cuda._get_device_index = torch.xpu._get_device_index + torch.cuda.__path__ = torch.xpu.__path__ + torch.cuda.Device = torch.xpu.Device + torch.cuda.IntTensor = torch.xpu.IntTensor + torch.cuda.ByteStorage = torch.xpu.ByteStorage + torch.cuda.set_stream = torch.xpu.set_stream + torch.cuda.BoolStorage = torch.xpu.BoolStorage + torch.cuda.os = torch.xpu.os + torch.cuda.torch = torch.xpu.torch + torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage + torch.cuda.Union = torch.xpu.Union + torch.cuda.DoubleTensor = torch.xpu.DoubleTensor + torch.cuda.ShortTensor = torch.xpu.ShortTensor + torch.cuda.LongTensor = torch.xpu.LongTensor + torch.cuda.IntStorage = torch.xpu.IntStorage + torch.cuda.LongStorage = torch.xpu.LongStorage + torch.cuda.__annotations__ = torch.xpu.__annotations__ + torch.cuda.__package__ = torch.xpu.__package__ + torch.cuda.__builtins__ = torch.xpu.__builtins__ + torch.cuda.CharTensor = torch.xpu.CharTensor + torch.cuda.List = torch.xpu.List + torch.cuda._lazy_init = torch.xpu._lazy_init + torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor + torch.cuda.DoubleStorage = torch.xpu.DoubleStorage + torch.cuda.ByteTensor = torch.xpu.ByteTensor + torch.cuda.StreamContext = torch.xpu.StreamContext + torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage + torch.cuda.ShortStorage = torch.xpu.ShortStorage + torch.cuda._lazy_call = torch.xpu._lazy_call + torch.cuda.HalfStorage = torch.xpu.HalfStorage + torch.cuda.random = torch.xpu.random + torch.cuda._device = torch.xpu._device + torch.cuda.classproperty = torch.xpu.classproperty + torch.cuda.__name__ = torch.xpu.__name__ + torch.cuda._device_t = torch.xpu._device_t + torch.cuda.warnings = torch.xpu.warnings + torch.cuda.__spec__ = torch.xpu.__spec__ + torch.cuda.BoolTensor = torch.xpu.BoolTensor + torch.cuda.CharStorage = torch.xpu.CharStorage + torch.cuda.__file__ = torch.xpu.__file__ + torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork + # torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing - # Memory: - torch.cuda.memory = torch.xpu.memory - if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read(): - torch.xpu.empty_cache = lambda: None - torch.cuda.empty_cache = torch.xpu.empty_cache - torch.cuda.memory_stats = torch.xpu.memory_stats - torch.cuda.memory_summary = torch.xpu.memory_summary - torch.cuda.memory_snapshot = torch.xpu.memory_snapshot - torch.cuda.memory_allocated = torch.xpu.memory_allocated - torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated - torch.cuda.memory_reserved = torch.xpu.memory_reserved - torch.cuda.memory_cached = torch.xpu.memory_reserved - torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved - torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved - torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats - torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats - torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats - torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict - torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats + # Memory: + torch.cuda.memory = torch.xpu.memory + if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read(): + torch.xpu.empty_cache = lambda: None + torch.cuda.empty_cache = torch.xpu.empty_cache + torch.cuda.memory_stats = torch.xpu.memory_stats + torch.cuda.memory_summary = torch.xpu.memory_summary + torch.cuda.memory_snapshot = torch.xpu.memory_snapshot + torch.cuda.memory_allocated = torch.xpu.memory_allocated + torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated + torch.cuda.memory_reserved = torch.xpu.memory_reserved + torch.cuda.memory_cached = torch.xpu.memory_reserved + torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved + torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved + torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats + torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats + torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats + torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict + torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats - # RNG: - torch.cuda.get_rng_state = torch.xpu.get_rng_state - torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all - torch.cuda.set_rng_state = torch.xpu.set_rng_state - torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all - torch.cuda.manual_seed = torch.xpu.manual_seed - torch.cuda.manual_seed_all = torch.xpu.manual_seed_all - torch.cuda.seed = torch.xpu.seed - torch.cuda.seed_all = torch.xpu.seed_all - torch.cuda.initial_seed = torch.xpu.initial_seed + # RNG: + torch.cuda.get_rng_state = torch.xpu.get_rng_state + torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all + torch.cuda.set_rng_state = torch.xpu.set_rng_state + torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all + torch.cuda.manual_seed = torch.xpu.manual_seed + torch.cuda.manual_seed_all = torch.xpu.manual_seed_all + torch.cuda.seed = torch.xpu.seed + torch.cuda.seed_all = torch.xpu.seed_all + torch.cuda.initial_seed = torch.xpu.initial_seed + + # AMP: + torch.cuda.amp = torch.xpu.amp + torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled + torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype + + if not hasattr(torch.cuda.amp, "common"): + torch.cuda.amp.common = contextlib.nullcontext() + torch.cuda.amp.common.amp_definitely_not_available = lambda: False - # AMP: - torch.cuda.amp = torch.xpu.amp - if not hasattr(torch.cuda.amp, "common"): - torch.cuda.amp.common = contextlib.nullcontext() - torch.cuda.amp.common.amp_definitely_not_available = lambda: False - try: - torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler - except Exception: # pylint: disable=broad-exception-caught try: - from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error - gradscaler_init() torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler except Exception: # pylint: disable=broad-exception-caught - torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler + try: + from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error + gradscaler_init() + torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler + except Exception: # pylint: disable=broad-exception-caught + torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler - # C - torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream - ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_eu_count - ipex._C._DeviceProperties.major = 2023 - ipex._C._DeviceProperties.minor = 2 + # C + torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream + ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_eu_count + ipex._C._DeviceProperties.major = 2023 + ipex._C._DeviceProperties.minor = 2 - # Fix functions with ipex: - torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory] - torch._utils._get_available_device_type = lambda: "xpu" - torch.has_cuda = True - torch.cuda.has_half = True - torch.cuda.is_bf16_supported = lambda *args, **kwargs: True - torch.cuda.is_fp16_supported = lambda *args, **kwargs: True - torch.version.cuda = "11.7" - torch.cuda.get_device_capability = lambda *args, **kwargs: [11,7] - torch.cuda.get_device_properties.major = 11 - torch.cuda.get_device_properties.minor = 7 - torch.cuda.ipc_collect = lambda *args, **kwargs: None - torch.cuda.utilization = lambda *args, **kwargs: 0 + # Fix functions with ipex: + torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory] + torch._utils._get_available_device_type = lambda: "xpu" + torch.has_cuda = True + torch.cuda.has_half = True + torch.cuda.is_bf16_supported = lambda *args, **kwargs: True + torch.cuda.is_fp16_supported = lambda *args, **kwargs: True + torch.backends.cuda.is_built = lambda *args, **kwargs: True + torch.version.cuda = "12.1" + torch.cuda.get_device_capability = lambda *args, **kwargs: [12,1] + torch.cuda.get_device_properties.major = 12 + torch.cuda.get_device_properties.minor = 1 + torch.cuda.ipc_collect = lambda *args, **kwargs: None + torch.cuda.utilization = lambda *args, **kwargs: 0 - ipex_hijacks() - if not torch.xpu.has_fp64_dtype(): - try: - from .diffusers import ipex_diffusers - ipex_diffusers() - except Exception: # pylint: disable=broad-exception-caught - pass + ipex_hijacks() + if not torch.xpu.has_fp64_dtype() or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None: + try: + from .diffusers import ipex_diffusers + ipex_diffusers() + except Exception: # pylint: disable=broad-exception-caught + pass + torch.cuda.is_xpu_hijacked = True except Exception as e: return False, e return True, None diff --git a/library/ipex/attention.py b/library/ipex/attention.py index e98807a84..8253c5b17 100644 --- a/library/ipex/attention.py +++ b/library/ipex/attention.py @@ -124,6 +124,7 @@ def torch_bmm_32_bit(input, mat2, *, out=None): ) else: return original_torch_bmm(input, mat2, out=out) + torch.xpu.synchronize(input.device) return hidden_states original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention @@ -172,4 +173,5 @@ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropo ) else: return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal) + torch.xpu.synchronize(query.device) return hidden_states diff --git a/library/ipex/diffusers.py b/library/ipex/diffusers.py index 47b0375ae..732a18568 100644 --- a/library/ipex/diffusers.py +++ b/library/ipex/diffusers.py @@ -149,6 +149,7 @@ def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice del attn_slice + torch.xpu.synchronize(query.device) else: query_slice = query[start_idx:end_idx] key_slice = key[start_idx:end_idx] @@ -283,6 +284,7 @@ def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, hidden_states[start_idx:end_idx] = attn_slice del attn_slice + torch.xpu.synchronize(query.device) else: attention_probs = attn.get_attention_scores(query, key, attention_mask) hidden_states = torch.bmm(attention_probs, value) diff --git a/library/ipex/hijacks.py b/library/ipex/hijacks.py index b6d246dd2..b1b9ccf0e 100644 --- a/library/ipex/hijacks.py +++ b/library/ipex/hijacks.py @@ -1,17 +1,22 @@ -import contextlib +import os +from functools import wraps +from contextlib import nullcontext import torch import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import +import numpy as np + +device_supports_fp64 = torch.xpu.has_fp64_dtype() # pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument if isinstance(device_ids, list) and len(device_ids) > 1: - print("IPEX backend doesn't support DataParallel on multiple XPU devices") + logger.error("IPEX backend doesn't support DataParallel on multiple XPU devices") return module.to("xpu") def return_null_context(*args, **kwargs): # pylint: disable=unused-argument - return contextlib.nullcontext() + return nullcontext() @property def is_cuda(self): @@ -25,15 +30,17 @@ def return_xpu(device): # Autocast -original_autocast = torch.autocast -def ipex_autocast(*args, **kwargs): - if len(args) > 0 and args[0] == "cuda": - return original_autocast("xpu", *args[1:], **kwargs) +original_autocast_init = torch.amp.autocast_mode.autocast.__init__ +@wraps(torch.amp.autocast_mode.autocast.__init__) +def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=None): + if device_type == "cuda": + return original_autocast_init(self, device_type="xpu", dtype=dtype, enabled=enabled, cache_enabled=cache_enabled) else: - return original_autocast(*args, **kwargs) + return original_autocast_init(self, device_type=device_type, dtype=dtype, enabled=enabled, cache_enabled=cache_enabled) # Latent Antialias CPU Offload: original_interpolate = torch.nn.functional.interpolate +@wraps(torch.nn.functional.interpolate) def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments if antialias or align_corners is not None: return_device = tensor.device @@ -44,15 +51,29 @@ def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corn return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode, align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias) + # Diffusers Float64 (Alchemist GPUs doesn't support 64 bit): original_from_numpy = torch.from_numpy +@wraps(torch.from_numpy) def from_numpy(ndarray): if ndarray.dtype == float: return original_from_numpy(ndarray.astype('float32')) else: return original_from_numpy(ndarray) -if torch.xpu.has_fp64_dtype(): +original_as_tensor = torch.as_tensor +@wraps(torch.as_tensor) +def as_tensor(data, dtype=None, device=None): + if check_device(device): + device = return_xpu(device) + if isinstance(data, np.ndarray) and data.dtype == float and not ( + (isinstance(device, torch.device) and device.type == "cpu") or (isinstance(device, str) and "cpu" in device)): + return original_as_tensor(data, dtype=torch.float32, device=device) + else: + return original_as_tensor(data, dtype=dtype, device=device) + + +if device_supports_fp64 and os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is None: original_torch_bmm = torch.bmm original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention else: @@ -66,20 +87,25 @@ def from_numpy(ndarray): # Data Type Errors: +@wraps(torch.bmm) def torch_bmm(input, mat2, *, out=None): if input.dtype != mat2.dtype: mat2 = mat2.to(input.dtype) return original_torch_bmm(input, mat2, out=out) +@wraps(torch.nn.functional.scaled_dot_product_attention) def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False): if query.dtype != key.dtype: key = key.to(dtype=query.dtype) if query.dtype != value.dtype: value = value.to(dtype=query.dtype) + if attn_mask is not None and query.dtype != attn_mask.dtype: + attn_mask = attn_mask.to(dtype=query.dtype) return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal) # A1111 FP16 original_functional_group_norm = torch.nn.functional.group_norm +@wraps(torch.nn.functional.group_norm) def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05): if weight is not None and input.dtype != weight.data.dtype: input = input.to(dtype=weight.data.dtype) @@ -89,6 +115,7 @@ def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05): # A1111 BF16 original_functional_layer_norm = torch.nn.functional.layer_norm +@wraps(torch.nn.functional.layer_norm) def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05): if weight is not None and input.dtype != weight.data.dtype: input = input.to(dtype=weight.data.dtype) @@ -98,6 +125,7 @@ def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1 # Training original_functional_linear = torch.nn.functional.linear +@wraps(torch.nn.functional.linear) def functional_linear(input, weight, bias=None): if input.dtype != weight.data.dtype: input = input.to(dtype=weight.data.dtype) @@ -106,6 +134,7 @@ def functional_linear(input, weight, bias=None): return original_functional_linear(input, weight, bias=bias) original_functional_conv2d = torch.nn.functional.conv2d +@wraps(torch.nn.functional.conv2d) def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): if input.dtype != weight.data.dtype: input = input.to(dtype=weight.data.dtype) @@ -115,6 +144,7 @@ def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, # A1111 Embedding BF16 original_torch_cat = torch.cat +@wraps(torch.cat) def torch_cat(tensor, *args, **kwargs): if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype): return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs) @@ -123,6 +153,7 @@ def torch_cat(tensor, *args, **kwargs): # SwinIR BF16: original_functional_pad = torch.nn.functional.pad +@wraps(torch.nn.functional.pad) def functional_pad(input, pad, mode='constant', value=None): if mode == 'reflect' and input.dtype == torch.bfloat16: return original_functional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16) @@ -131,13 +162,20 @@ def functional_pad(input, pad, mode='constant', value=None): original_torch_tensor = torch.tensor -def torch_tensor(*args, device=None, **kwargs): +@wraps(torch.tensor) +def torch_tensor(data, *args, dtype=None, device=None, **kwargs): if check_device(device): - return original_torch_tensor(*args, device=return_xpu(device), **kwargs) - else: - return original_torch_tensor(*args, device=device, **kwargs) + device = return_xpu(device) + if not device_supports_fp64: + if (isinstance(device, torch.device) and device.type == "xpu") or (isinstance(device, str) and "xpu" in device): + if dtype == torch.float64: + dtype = torch.float32 + elif dtype is None and (hasattr(data, "dtype") and (data.dtype == torch.float64 or data.dtype == float)): + dtype = torch.float32 + return original_torch_tensor(data, *args, dtype=dtype, device=device, **kwargs) original_Tensor_to = torch.Tensor.to +@wraps(torch.Tensor.to) def Tensor_to(self, device=None, *args, **kwargs): if check_device(device): return original_Tensor_to(self, return_xpu(device), *args, **kwargs) @@ -145,6 +183,7 @@ def Tensor_to(self, device=None, *args, **kwargs): return original_Tensor_to(self, device, *args, **kwargs) original_Tensor_cuda = torch.Tensor.cuda +@wraps(torch.Tensor.cuda) def Tensor_cuda(self, device=None, *args, **kwargs): if check_device(device): return original_Tensor_cuda(self, return_xpu(device), *args, **kwargs) @@ -152,6 +191,7 @@ def Tensor_cuda(self, device=None, *args, **kwargs): return original_Tensor_cuda(self, device, *args, **kwargs) original_UntypedStorage_init = torch.UntypedStorage.__init__ +@wraps(torch.UntypedStorage.__init__) def UntypedStorage_init(*args, device=None, **kwargs): if check_device(device): return original_UntypedStorage_init(*args, device=return_xpu(device), **kwargs) @@ -159,6 +199,7 @@ def UntypedStorage_init(*args, device=None, **kwargs): return original_UntypedStorage_init(*args, device=device, **kwargs) original_UntypedStorage_cuda = torch.UntypedStorage.cuda +@wraps(torch.UntypedStorage.cuda) def UntypedStorage_cuda(self, device=None, *args, **kwargs): if check_device(device): return original_UntypedStorage_cuda(self, return_xpu(device), *args, **kwargs) @@ -166,6 +207,7 @@ def UntypedStorage_cuda(self, device=None, *args, **kwargs): return original_UntypedStorage_cuda(self, device, *args, **kwargs) original_torch_empty = torch.empty +@wraps(torch.empty) def torch_empty(*args, device=None, **kwargs): if check_device(device): return original_torch_empty(*args, device=return_xpu(device), **kwargs) @@ -173,6 +215,7 @@ def torch_empty(*args, device=None, **kwargs): return original_torch_empty(*args, device=device, **kwargs) original_torch_randn = torch.randn +@wraps(torch.randn) def torch_randn(*args, device=None, **kwargs): if check_device(device): return original_torch_randn(*args, device=return_xpu(device), **kwargs) @@ -180,6 +223,7 @@ def torch_randn(*args, device=None, **kwargs): return original_torch_randn(*args, device=device, **kwargs) original_torch_ones = torch.ones +@wraps(torch.ones) def torch_ones(*args, device=None, **kwargs): if check_device(device): return original_torch_ones(*args, device=return_xpu(device), **kwargs) @@ -187,6 +231,7 @@ def torch_ones(*args, device=None, **kwargs): return original_torch_ones(*args, device=device, **kwargs) original_torch_zeros = torch.zeros +@wraps(torch.zeros) def torch_zeros(*args, device=None, **kwargs): if check_device(device): return original_torch_zeros(*args, device=return_xpu(device), **kwargs) @@ -194,6 +239,7 @@ def torch_zeros(*args, device=None, **kwargs): return original_torch_zeros(*args, device=device, **kwargs) original_torch_linspace = torch.linspace +@wraps(torch.linspace) def torch_linspace(*args, device=None, **kwargs): if check_device(device): return original_torch_linspace(*args, device=return_xpu(device), **kwargs) @@ -201,6 +247,7 @@ def torch_linspace(*args, device=None, **kwargs): return original_torch_linspace(*args, device=device, **kwargs) original_torch_Generator = torch.Generator +@wraps(torch.Generator) def torch_Generator(device=None): if check_device(device): return original_torch_Generator(return_xpu(device)) @@ -208,12 +255,14 @@ def torch_Generator(device=None): return original_torch_Generator(device) original_torch_load = torch.load +@wraps(torch.load) def torch_load(f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs): if check_device(map_location): return original_torch_load(f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs) else: return original_torch_load(f, map_location=map_location, pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs) + # Hijack Functions: def ipex_hijacks(): torch.tensor = torch_tensor @@ -232,7 +281,7 @@ def ipex_hijacks(): torch.backends.cuda.sdp_kernel = return_null_context torch.nn.DataParallel = DummyDataParallel torch.UntypedStorage.is_cuda = is_cuda - torch.autocast = ipex_autocast + torch.amp.autocast_mode.autocast.__init__ = autocast_init torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention torch.nn.functional.group_norm = functional_group_norm @@ -244,5 +293,6 @@ def ipex_hijacks(): torch.bmm = torch_bmm torch.cat = torch_cat - if not torch.xpu.has_fp64_dtype(): + if not device_supports_fp64: torch.from_numpy = from_numpy + torch.as_tensor = as_tensor diff --git a/library/lpw_stable_diffusion.py b/library/lpw_stable_diffusion.py index 3963e9b15..5717233d4 100644 --- a/library/lpw_stable_diffusion.py +++ b/library/lpw_stable_diffusion.py @@ -17,7 +17,6 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker from diffusers.utils import logging - try: from diffusers.utils import PIL_INTERPOLATION except ImportError: @@ -626,7 +625,7 @@ def check_inputs(self, prompt, height, width, strength, callback_steps): raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") if height % 8 != 0 or width % 8 != 0: - print(height, width) + logger.info(f'{height} {width}') raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if (callback_steps is None) or ( diff --git a/library/model_util.py b/library/model_util.py index 1f40ce324..be410a026 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -3,22 +3,20 @@ import math import os -import torch - -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +import torch +from library.device_utils import init_ipex +init_ipex() - ipex_init() -except Exception: - pass import diffusers from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel from safetensors.torch import load_file, save_file from library.original_unet import UNet2DConditionModel +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) # DiffUsers版StableDiffusionのモデルパラメータ NUM_TRAIN_TIMESTEPS = 1000 @@ -950,7 +948,7 @@ def convert_vae_state_dict(vae_state_dict): for k, v in new_state_dict.items(): for weight_name in weights_to_convert: if f"mid.attn_1.{weight_name}.weight" in k: - # print(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1") + # logger.info(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1") new_state_dict[k] = reshape_weight_for_sd(v) return new_state_dict @@ -1008,7 +1006,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt unet = UNet2DConditionModel(**unet_config).to(device) info = unet.load_state_dict(converted_unet_checkpoint) - print("loading u-net:", info) + logger.info(f"loading u-net: {info}") # Convert the VAE model. vae_config = create_vae_diffusers_config() @@ -1016,7 +1014,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt vae = AutoencoderKL(**vae_config).to(device) info = vae.load_state_dict(converted_vae_checkpoint) - print("loading vae:", info) + logger.info(f"loading vae: {info}") # convert text_model if v2: @@ -1050,7 +1048,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt # logging.set_verbosity_error() # don't show annoying warning # text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device) # logging.set_verbosity_warning() - # print(f"config: {text_model.config}") + # logger.info(f"config: {text_model.config}") cfg = CLIPTextConfig( vocab_size=49408, hidden_size=768, @@ -1073,7 +1071,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt ) text_model = CLIPTextModel._from_config(cfg) info = text_model.load_state_dict(converted_text_encoder_checkpoint) - print("loading text encoder:", info) + logger.info(f"loading text encoder: {info}") return text_model, vae, unet @@ -1148,7 +1146,7 @@ def convert_key(key): # 最後の層などを捏造するか if make_dummy_weights: - print("make dummy weights for resblock.23, text_projection and logit scale.") + logger.info("make dummy weights for resblock.23, text_projection and logit scale.") keys = list(new_sd.keys()) for key in keys: if key.startswith("transformer.resblocks.22."): @@ -1267,14 +1265,14 @@ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_mod def load_vae(vae_id, dtype): - print(f"load VAE: {vae_id}") + logger.info(f"load VAE: {vae_id}") if os.path.isdir(vae_id) or not os.path.isfile(vae_id): # Diffusers local/remote try: vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype) except EnvironmentError as e: - print(f"exception occurs in loading vae: {e}") - print("retry with subfolder='vae'") + logger.error(f"exception occurs in loading vae: {e}") + logger.error("retry with subfolder='vae'") vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype) return vae @@ -1346,13 +1344,13 @@ def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64) if __name__ == "__main__": resos = make_bucket_resolutions((512, 768)) - print(len(resos)) - print(resos) + logger.info(f"{len(resos)}") + logger.info(f"{resos}") aspect_ratios = [w / h for w, h in resos] - print(aspect_ratios) + logger.info(f"{aspect_ratios}") ars = set() for ar in aspect_ratios: if ar in ars: - print("error! duplicate ar:", ar) + logger.error(f"error! duplicate ar: {ar}") ars.add(ar) diff --git a/library/original_unet.py b/library/original_unet.py index 030c5c9ec..e944ff22b 100644 --- a/library/original_unet.py +++ b/library/original_unet.py @@ -113,6 +113,10 @@ from torch import nn from torch.nn import functional as F from einops import rearrange +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) BLOCK_OUT_CHANNELS: Tuple[int] = (320, 640, 1280, 1280) TIMESTEP_INPUT_DIM = BLOCK_OUT_CHANNELS[0] @@ -1380,7 +1384,7 @@ def __init__( ): super().__init__() assert sample_size is not None, "sample_size must be specified" - print( + logger.info( f"UNet2DConditionModel: {sample_size}, {attention_head_dim}, {cross_attention_dim}, {use_linear_projection}, {upcast_attention}" ) @@ -1514,7 +1518,7 @@ def set_use_sdpa(self, sdpa: bool) -> None: def set_gradient_checkpointing(self, value=False): modules = self.down_blocks + [self.mid_block] + self.up_blocks for module in modules: - print(module.__class__.__name__, module.gradient_checkpointing, "->", value) + logger.info(f"{module.__class__.__name__} {module.gradient_checkpointing} -> {value}") module.gradient_checkpointing = value # endregion @@ -1709,14 +1713,14 @@ def __call__(self, *args, **kwargs): def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5): if ds_depth_1 is None: - print("Deep Shrink is disabled.") + logger.info("Deep Shrink is disabled.") self.ds_depth_1 = None self.ds_timesteps_1 = None self.ds_depth_2 = None self.ds_timesteps_2 = None self.ds_ratio = None else: - print( + logger.info( f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]" ) self.ds_depth_1 = ds_depth_1 diff --git a/library/sai_model_spec.py b/library/sai_model_spec.py index 472686ba4..a63bd82ec 100644 --- a/library/sai_model_spec.py +++ b/library/sai_model_spec.py @@ -5,6 +5,10 @@ import os from typing import List, Optional, Tuple, Union import safetensors +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) r""" # Metadata Example @@ -231,7 +235,7 @@ def build_metadata( # # assert all values are filled # assert all([v is not None for v in metadata.values()]), metadata if not all([v is not None for v in metadata.values()]): - print(f"Internal error: some metadata values are None: {metadata}") + logger.error(f"Internal error: some metadata values are None: {metadata}") return metadata diff --git a/library/sdxl_lpw_stable_diffusion.py b/library/sdxl_lpw_stable_diffusion.py index e03ee4056..03b182566 100644 --- a/library/sdxl_lpw_stable_diffusion.py +++ b/library/sdxl_lpw_stable_diffusion.py @@ -923,7 +923,11 @@ def __call__( if up1 is not None: uncond_pool = up1 - dtype = self.unet.dtype + unet_dtype = self.unet.dtype + dtype = unet_dtype + if hasattr(dtype, "itemsize") and dtype.itemsize == 1: # fp8 + dtype = torch.float16 + self.unet.to(dtype) # 4. Preprocess image and mask if isinstance(image, PIL.Image.Image): @@ -1028,6 +1032,7 @@ def __call__( if is_cancelled_callback is not None and is_cancelled_callback(): return None + self.unet.to(unet_dtype) return latents def latents_to_image(self, latents): diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index 08b90c393..f03f1bae5 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -7,7 +7,10 @@ from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel from library import model_util from library import sdxl_original_unet - +from .utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) VAE_SCALE_FACTOR = 0.13025 MODEL_VERSION_SDXL_BASE_V1_0 = "sdxl_base_v1-0" @@ -131,7 +134,7 @@ def convert_key(key): # temporary workaround for text_projection.weight.weight for Playground-v2 if "text_projection.weight.weight" in new_sd: - print(f"convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight") + logger.info("convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight") new_sd["text_projection.weight"] = new_sd["text_projection.weight.weight"] del new_sd["text_projection.weight.weight"] @@ -186,20 +189,20 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty checkpoint = None # U-Net - print("building U-Net") + logger.info("building U-Net") with init_empty_weights(): unet = sdxl_original_unet.SdxlUNet2DConditionModel() - print("loading U-Net from checkpoint") + logger.info("loading U-Net from checkpoint") unet_sd = {} for k in list(state_dict.keys()): if k.startswith("model.diffusion_model."): unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k) info = _load_state_dict_on_device(unet, unet_sd, device=map_location, dtype=dtype) - print("U-Net: ", info) + logger.info(f"U-Net: {info}") # Text Encoders - print("building text encoders") + logger.info("building text encoders") # Text Encoder 1 is same to Stability AI's SDXL text_model1_cfg = CLIPTextConfig( @@ -252,7 +255,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty with init_empty_weights(): text_model2 = CLIPTextModelWithProjection(text_model2_cfg) - print("loading text encoders from checkpoint") + logger.info("loading text encoders from checkpoint") te1_sd = {} te2_sd = {} for k in list(state_dict.keys()): @@ -266,22 +269,22 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty te1_sd.pop("text_model.embeddings.position_ids") info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location) # remain fp32 - print("text encoder 1:", info1) + logger.info(f"text encoder 1: {info1}") converted_sd, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77) info2 = _load_state_dict_on_device(text_model2, converted_sd, device=map_location) # remain fp32 - print("text encoder 2:", info2) + logger.info(f"text encoder 2: {info2}") # prepare vae - print("building VAE") + logger.info("building VAE") vae_config = model_util.create_vae_diffusers_config() with init_empty_weights(): vae = AutoencoderKL(**vae_config) - print("loading VAE from checkpoint") + logger.info("loading VAE from checkpoint") converted_vae_checkpoint = model_util.convert_ldm_vae_checkpoint(state_dict, vae_config) info = _load_state_dict_on_device(vae, converted_vae_checkpoint, device=map_location, dtype=dtype) - print("VAE:", info) + logger.info(f"VAE: {info}") ckpt_info = (epoch, global_step) if epoch is not None else None return text_model1, text_model2, vae, unet, logit_scale, ckpt_info diff --git a/library/sdxl_original_unet.py b/library/sdxl_original_unet.py index babda8ec5..17c345a89 100644 --- a/library/sdxl_original_unet.py +++ b/library/sdxl_original_unet.py @@ -30,7 +30,12 @@ from torch import nn from torch.nn import functional as F from einops import rearrange +from .utils import setup_logging +setup_logging() +import logging + +logger = logging.getLogger(__name__) IN_CHANNELS: int = 4 OUT_CHANNELS: int = 4 @@ -332,7 +337,7 @@ def forward_body(self, x, emb): def forward(self, x, emb): if self.training and self.gradient_checkpointing: - # print("ResnetBlock2D: gradient_checkpointing") + # logger.info("ResnetBlock2D: gradient_checkpointing") def create_custom_forward(func): def custom_forward(*inputs): @@ -366,7 +371,7 @@ def forward_body(self, hidden_states): def forward(self, hidden_states): if self.training and self.gradient_checkpointing: - # print("Downsample2D: gradient_checkpointing") + # logger.info("Downsample2D: gradient_checkpointing") def create_custom_forward(func): def custom_forward(*inputs): @@ -653,7 +658,7 @@ def forward_body(self, hidden_states, context=None, timestep=None): def forward(self, hidden_states, context=None, timestep=None): if self.training and self.gradient_checkpointing: - # print("BasicTransformerBlock: checkpointing") + # logger.info("BasicTransformerBlock: checkpointing") def create_custom_forward(func): def custom_forward(*inputs): @@ -796,7 +801,7 @@ def forward_body(self, hidden_states, output_size=None): def forward(self, hidden_states, output_size=None): if self.training and self.gradient_checkpointing: - # print("Upsample2D: gradient_checkpointing") + # logger.info("Upsample2D: gradient_checkpointing") def create_custom_forward(func): def custom_forward(*inputs): @@ -1046,7 +1051,7 @@ def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool) -> N for block in blocks: for module in block: if hasattr(module, "set_use_memory_efficient_attention"): - # print(module.__class__.__name__) + # logger.info(module.__class__.__name__) module.set_use_memory_efficient_attention(xformers, mem_eff) def set_use_sdpa(self, sdpa: bool) -> None: @@ -1061,7 +1066,7 @@ def set_gradient_checkpointing(self, value=False): for block in blocks: for module in block.modules(): if hasattr(module, "gradient_checkpointing"): - # print(module.__class__.__name__, module.gradient_checkpointing, "->", value) + # logger.info(f{module.__class__.__name__} {module.gradient_checkpointing} -> {value}") module.gradient_checkpointing = value # endregion @@ -1071,7 +1076,7 @@ def forward(self, x, timesteps=None, context=None, y=None, **kwargs): timesteps = timesteps.expand(x.shape[0]) hs = [] - t_emb = get_timestep_embedding(timesteps, self.model_channels) # , repeat_only=False) + t_emb = get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0) # , repeat_only=False) t_emb = t_emb.to(x.dtype) emb = self.time_embed(t_emb) @@ -1083,7 +1088,7 @@ def forward(self, x, timesteps=None, context=None, y=None, **kwargs): def call_module(module, h, emb, context): x = h for layer in module: - # print(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None) + # logger.info(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None) if isinstance(layer, ResnetBlock2D): x = layer(x, emb) elif isinstance(layer, Transformer2DModel): @@ -1129,20 +1134,20 @@ def __init__(self, original_unet: SdxlUNet2DConditionModel, **kwargs): # call original model's methods def __getattr__(self, name): return getattr(self.delegate, name) - + def __call__(self, *args, **kwargs): return self.delegate(*args, **kwargs) def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5): if ds_depth_1 is None: - print("Deep Shrink is disabled.") + logger.info("Deep Shrink is disabled.") self.ds_depth_1 = None self.ds_timesteps_1 = None self.ds_depth_2 = None self.ds_timesteps_2 = None self.ds_ratio = None else: - print( + logger.info( f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]" ) self.ds_depth_1 = ds_depth_1 @@ -1161,7 +1166,7 @@ def forward(self, x, timesteps=None, context=None, y=None, **kwargs): timesteps = timesteps.expand(x.shape[0]) hs = [] - t_emb = get_timestep_embedding(timesteps, _self.model_channels) # , repeat_only=False) + t_emb = get_timestep_embedding(timesteps, _self.model_channels, downscale_freq_shift=0) # , repeat_only=False) t_emb = t_emb.to(x.dtype) emb = _self.time_embed(t_emb) @@ -1229,7 +1234,7 @@ def call_module(module, h, emb, context): if __name__ == "__main__": import time - print("create unet") + logger.info("create unet") unet = SdxlUNet2DConditionModel() unet.to("cuda") @@ -1238,7 +1243,7 @@ def call_module(module, h, emb, context): unet.train() # 使用メモリ量確認用の疑似学習ループ - print("preparing optimizer") + logger.info("preparing optimizer") # optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working @@ -1253,12 +1258,12 @@ def call_module(module, h, emb, context): scaler = torch.cuda.amp.GradScaler(enabled=True) - print("start training") + logger.info("start training") steps = 10 batch_size = 1 for step in range(steps): - print(f"step {step}") + logger.info(f"step {step}") if step == 1: time_start = time.perf_counter() @@ -1278,4 +1283,4 @@ def call_module(module, h, emb, context): optimizer.zero_grad(set_to_none=True) time_end = time.perf_counter() - print(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps") + logger.info(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps") diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index 5ad748d15..1932bf881 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -1,14 +1,21 @@ import argparse -import gc import math import os from typing import Optional + import torch +from library.device_utils import init_ipex, clean_memory_on_device +init_ipex() + from accelerate import init_empty_weights from tqdm import tqdm from transformers import CLIPTokenizer from library import model_util, sdxl_model_util, train_util, sdxl_original_unet from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline +from .utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) TOKENIZER1_PATH = "openai/clip-vit-large-patch14" TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" @@ -21,7 +28,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype): model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16 for pi in range(accelerator.state.num_processes): if pi == accelerator.state.local_process_index: - print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") + logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") ( load_stable_diffusion_format, @@ -47,8 +54,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype): unet.to(accelerator.device) vae.to(accelerator.device) - gc.collect() - torch.cuda.empty_cache() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info @@ -62,7 +68,7 @@ def _load_target_model( load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers if load_stable_diffusion_format: - print(f"load StableDiffusion checkpoint: {name_or_path}") + logger.info(f"load StableDiffusion checkpoint: {name_or_path}") ( text_encoder1, text_encoder2, @@ -76,7 +82,7 @@ def _load_target_model( from diffusers import StableDiffusionXLPipeline variant = "fp16" if weight_dtype == torch.float16 else None - print(f"load Diffusers pretrained models: {name_or_path}, variant={variant}") + logger.info(f"load Diffusers pretrained models: {name_or_path}, variant={variant}") try: try: pipe = StableDiffusionXLPipeline.from_pretrained( @@ -84,12 +90,12 @@ def _load_target_model( ) except EnvironmentError as ex: if variant is not None: - print("try to load fp32 model") + logger.info("try to load fp32 model") pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, variant=None, tokenizer=None) else: raise ex except EnvironmentError as ex: - print( + logger.error( f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}" ) raise ex @@ -112,7 +118,7 @@ def _load_target_model( with init_empty_weights(): unet = sdxl_original_unet.SdxlUNet2DConditionModel() # overwrite unet sdxl_model_util._load_state_dict_on_device(unet, state_dict, device=device, dtype=model_dtype) - print("U-Net converted to original U-Net") + logger.info("U-Net converted to original U-Net") logit_scale = None ckpt_info = None @@ -120,13 +126,13 @@ def _load_target_model( # VAEを読み込む if vae_path is not None: vae = model_util.load_vae(vae_path, weight_dtype) - print("additional VAE loaded") + logger.info("additional VAE loaded") return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info def load_tokenizers(args: argparse.Namespace): - print("prepare tokenizers") + logger.info("prepare tokenizers") original_paths = [TOKENIZER1_PATH, TOKENIZER2_PATH] tokeniers = [] @@ -135,14 +141,14 @@ def load_tokenizers(args: argparse.Namespace): if args.tokenizer_cache_dir: local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_")) if os.path.exists(local_tokenizer_path): - print(f"load tokenizer from cache: {local_tokenizer_path}") + logger.info(f"load tokenizer from cache: {local_tokenizer_path}") tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) if tokenizer is None: tokenizer = CLIPTokenizer.from_pretrained(original_path) if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path): - print(f"save Tokenizer to cache: {local_tokenizer_path}") + logger.info(f"save Tokenizer to cache: {local_tokenizer_path}") tokenizer.save_pretrained(local_tokenizer_path) if i == 1: @@ -151,7 +157,7 @@ def load_tokenizers(args: argparse.Namespace): tokeniers.append(tokenizer) if hasattr(args, "max_token_length") and args.max_token_length is not None: - print(f"update token length: {args.max_token_length}") + logger.info(f"update token length: {args.max_token_length}") return tokeniers @@ -332,23 +338,23 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser): def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True): assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません" if args.v_parameterization: - print("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります") + logger.warning("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります") if args.clip_skip is not None: - print("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません") + logger.warning("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません") # if args.multires_noise_iterations: - # print( + # logger.info( # f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります" # ) # else: # if args.noise_offset is None: # args.noise_offset = DEFAULT_NOISE_OFFSET # elif args.noise_offset != DEFAULT_NOISE_OFFSET: - # print( + # logger.info( # f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています" # ) - # print(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました") + # logger.info(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました") assert ( not hasattr(args, "weighted_captions") or not args.weighted_captions @@ -357,7 +363,7 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin if supportTextEncoderCaching: if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: args.cache_text_encoder_outputs = True - print( + logger.warning( "cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / " + "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました" ) diff --git a/library/slicing_vae.py b/library/slicing_vae.py index 5c4e056d3..ea7653429 100644 --- a/library/slicing_vae.py +++ b/library/slicing_vae.py @@ -26,7 +26,10 @@ from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block from diffusers.models.vae import DecoderOutput, DiagonalGaussianDistribution from diffusers.models.autoencoder_kl import AutoencoderKLOutput - +from .utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def slice_h(x, num_slices): # slice with pad 1 both sides: to eliminate side effect of padding of conv2d @@ -89,7 +92,7 @@ def resblock_forward(_self, num_slices, input_tensor, temb, **kwargs): # sliced_tensor = torch.chunk(x, num_div, dim=1) # sliced_weight = torch.chunk(norm.weight, num_div, dim=0) # sliced_bias = torch.chunk(norm.bias, num_div, dim=0) - # print(sliced_tensor[0].shape, num_div, sliced_weight[0].shape, sliced_bias[0].shape) + # logger.info(sliced_tensor[0].shape, num_div, sliced_weight[0].shape, sliced_bias[0].shape) # normed_tensor = [] # for i in range(num_div): # n = torch.group_norm(sliced_tensor[i], norm.num_groups, sliced_weight[i], sliced_bias[i], norm.eps) @@ -243,7 +246,7 @@ def forward(*args, **kwargs): self.num_slices = num_slices div = num_slices / (2 ** (len(self.down_blocks) - 1)) # 深い層はそこまで分割しなくていいので適宜減らす - # print(f"initial divisor: {div}") + # logger.info(f"initial divisor: {div}") if div >= 2: div = int(div) for resnet in self.mid_block.resnets: @@ -253,11 +256,11 @@ def forward(*args, **kwargs): for i, down_block in enumerate(self.down_blocks[::-1]): if div >= 2: div = int(div) - # print(f"down block: {i} divisor: {div}") + # logger.info(f"down block: {i} divisor: {div}") for resnet in down_block.resnets: resnet.forward = wrapper(resblock_forward, resnet, div) if down_block.downsamplers is not None: - # print("has downsample") + # logger.info("has downsample") for downsample in down_block.downsamplers: downsample.forward = wrapper(self.downsample_forward, downsample, div * 2) div *= 2 @@ -307,7 +310,7 @@ def forward(self, x): def downsample_forward(self, _self, num_slices, hidden_states): assert hidden_states.shape[1] == _self.channels assert _self.use_conv and _self.padding == 0 - print("downsample forward", num_slices, hidden_states.shape) + logger.info(f"downsample forward {num_slices} {hidden_states.shape}") org_device = hidden_states.device cpu_device = torch.device("cpu") @@ -350,7 +353,7 @@ def downsample_forward(self, _self, num_slices, hidden_states): hidden_states = torch.cat([hidden_states, x], dim=2) hidden_states = hidden_states.to(org_device) - # print("downsample forward done", hidden_states.shape) + # logger.info(f"downsample forward done {hidden_states.shape}") return hidden_states @@ -426,7 +429,7 @@ def forward(*args, **kwargs): self.num_slices = num_slices div = num_slices / (2 ** (len(self.up_blocks) - 1)) - print(f"initial divisor: {div}") + logger.info(f"initial divisor: {div}") if div >= 2: div = int(div) for resnet in self.mid_block.resnets: @@ -436,11 +439,11 @@ def forward(*args, **kwargs): for i, up_block in enumerate(self.up_blocks): if div >= 2: div = int(div) - # print(f"up block: {i} divisor: {div}") + # logger.info(f"up block: {i} divisor: {div}") for resnet in up_block.resnets: resnet.forward = wrapper(resblock_forward, resnet, div) if up_block.upsamplers is not None: - # print("has upsample") + # logger.info("has upsample") for upsample in up_block.upsamplers: upsample.forward = wrapper(self.upsample_forward, upsample, div * 2) div *= 2 @@ -528,7 +531,7 @@ def upsample_forward(self, _self, num_slices, hidden_states, output_size=None): del x hidden_states = torch.cat(sliced, dim=2) - # print("us hidden_states", hidden_states.shape) + # logger.info(f"us hidden_states {hidden_states.shape}") del sliced hidden_states = hidden_states.to(org_device) diff --git a/library/train_util.py b/library/train_util.py index ff161feab..99aeea90d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -6,6 +6,7 @@ import datetime import importlib import json +import logging import pathlib import re import shutil @@ -19,8 +20,7 @@ Tuple, Union, ) -from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs -import gc +from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState import glob import math import os @@ -31,7 +31,12 @@ import toml from tqdm import tqdm + import torch +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer from torchvision import transforms @@ -64,7 +69,12 @@ import library.model_util as model_util import library.huggingface_util as huggingface_util import library.sai_model_spec as sai_model_spec +from library.utils import setup_logging +setup_logging() +import logging + +logger = logging.getLogger(__name__) # from library.attention_processors import FlashAttnProcessor # from library.hypernetwork import replace_attentions_for_hypernetwork from library.original_unet import UNet2DConditionModel @@ -73,6 +83,8 @@ TOKENIZER_PATH = "openai/clip-vit-large-patch14" V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ +HIGH_VRAM = False + # checkpointファイル名 EPOCH_STATE_NAME = "{}-{:06d}-state" EPOCH_FILE_NAME = "{}-{:06d}" @@ -211,7 +223,7 @@ def add_if_new_reso(self, reso): self.reso_to_id[reso] = bucket_id self.resos.append(reso) self.buckets.append([]) - # print(reso, bucket_id, len(self.buckets)) + # logger.info(reso, bucket_id, len(self.buckets)) def round_to_steps(self, x): x = int(x + 0.5) @@ -237,7 +249,7 @@ def select_bucket(self, image_width, image_height): scale = reso[0] / image_width resized_size = (int(image_width * scale + 0.5), int(image_height * scale + 0.5)) - # print("use predef", image_width, image_height, reso, resized_size) + # logger.info(f"use predef, {image_width}, {image_height}, {reso}, {resized_size}") else: # 縮小のみを行う if image_width * image_height > self.max_area: @@ -256,21 +268,21 @@ def select_bucket(self, image_width, image_height): b_width_in_hr = self.round_to_steps(b_height_rounded * aspect_ratio) ar_height_rounded = b_width_in_hr / b_height_rounded - # print(b_width_rounded, b_height_in_wr, ar_width_rounded) - # print(b_width_in_hr, b_height_rounded, ar_height_rounded) + # logger.info(b_width_rounded, b_height_in_wr, ar_width_rounded) + # logger.info(b_width_in_hr, b_height_rounded, ar_height_rounded) if abs(ar_width_rounded - aspect_ratio) < abs(ar_height_rounded - aspect_ratio): resized_size = (b_width_rounded, int(b_width_rounded / aspect_ratio + 0.5)) else: resized_size = (int(b_height_rounded * aspect_ratio + 0.5), b_height_rounded) - # print(resized_size) + # logger.info(resized_size) else: resized_size = (image_width, image_height) # リサイズは不要 # 画像のサイズ未満をbucketのサイズとする(paddingせずにcroppingする) bucket_width = resized_size[0] - resized_size[0] % self.reso_steps bucket_height = resized_size[1] - resized_size[1] % self.reso_steps - # print("use arbitrary", image_width, image_height, resized_size, bucket_width, bucket_height) + # logger.info(f"use arbitrary {image_width}, {image_height}, {resized_size}, {bucket_width}, {bucket_height}") reso = (bucket_width, bucket_height) @@ -352,6 +364,8 @@ def __init__( caption_separator: str, keep_tokens: int, keep_tokens_separator: str, + secondary_separator: Optional[str], + enable_wildcard: bool, color_aug: bool, flip_aug: bool, face_crop_aug_range: Optional[Tuple[float, float]], @@ -370,6 +384,8 @@ def __init__( self.caption_separator = caption_separator self.keep_tokens = keep_tokens self.keep_tokens_separator = keep_tokens_separator + self.secondary_separator = secondary_separator + self.enable_wildcard = enable_wildcard self.color_aug = color_aug self.flip_aug = flip_aug self.face_crop_aug_range = face_crop_aug_range @@ -398,6 +414,8 @@ def __init__( caption_separator: str, keep_tokens, keep_tokens_separator, + secondary_separator, + enable_wildcard, color_aug, flip_aug, face_crop_aug_range, @@ -419,6 +437,8 @@ def __init__( caption_separator, keep_tokens, keep_tokens_separator, + secondary_separator, + enable_wildcard, color_aug, flip_aug, face_crop_aug_range, @@ -454,6 +474,8 @@ def __init__( caption_separator, keep_tokens, keep_tokens_separator, + secondary_separator, + enable_wildcard, color_aug, flip_aug, face_crop_aug_range, @@ -475,6 +497,8 @@ def __init__( caption_separator, keep_tokens, keep_tokens_separator, + secondary_separator, + enable_wildcard, color_aug, flip_aug, face_crop_aug_range, @@ -507,6 +531,8 @@ def __init__( caption_separator, keep_tokens, keep_tokens_separator, + secondary_separator, + enable_wildcard, color_aug, flip_aug, face_crop_aug_range, @@ -528,6 +554,8 @@ def __init__( caption_separator, keep_tokens, keep_tokens_separator, + secondary_separator, + enable_wildcard, color_aug, flip_aug, face_crop_aug_range, @@ -558,6 +586,7 @@ def __init__( tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]], max_token_length: int, resolution: Optional[Tuple[int, int]], + network_multiplier: float, debug_dataset: bool, ) -> None: super().__init__() @@ -567,6 +596,7 @@ def __init__( self.max_token_length = max_token_length # width/height is used when enable_bucket==False self.width, self.height = (None, None) if resolution is None else resolution + self.network_multiplier = network_multiplier self.debug_dataset = debug_dataset self.subsets: List[Union[DreamBoothSubset, FineTuningSubset]] = [] @@ -661,15 +691,48 @@ def process_caption(self, subset: BaseSubset, caption): if is_drop_out: caption = "" else: + # process wildcards + if subset.enable_wildcard: + # if caption is multiline, random choice one line + if "\n" in caption: + caption = random.choice(caption.split("\n")) + + # wildcard is like '{aaa|bbb|ccc...}' + # escape the curly braces like {{ or }} + replacer1 = "⦅" + replacer2 = "⦆" + while replacer1 in caption or replacer2 in caption: + replacer1 += "⦅" + replacer2 += "⦆" + + caption = caption.replace("{{", replacer1).replace("}}", replacer2) + + # replace the wildcard + def replace_wildcard(match): + return random.choice(match.group(1).split("|")) + + caption = re.sub(r"\{([^}]+)\}", replace_wildcard, caption) + + # unescape the curly braces + caption = caption.replace(replacer1, "{").replace(replacer2, "}") + else: + # if caption is multiline, use the first line + caption = caption.split("\n")[0] + if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0: fixed_tokens = [] flex_tokens = [] + fixed_suffix_tokens = [] if ( hasattr(subset, "keep_tokens_separator") and subset.keep_tokens_separator and subset.keep_tokens_separator in caption ): fixed_part, flex_part = caption.split(subset.keep_tokens_separator, 1) + if subset.keep_tokens_separator in flex_part: + flex_part, fixed_suffix_part = flex_part.split(subset.keep_tokens_separator, 1) + fixed_suffix_tokens = [t.strip() for t in fixed_suffix_part.split(subset.caption_separator) if t.strip()] + fixed_tokens = [t.strip() for t in fixed_part.split(subset.caption_separator) if t.strip()] flex_tokens = [t.strip() for t in flex_part.split(subset.caption_separator) if t.strip()] else: @@ -704,7 +767,11 @@ def dropout_tags(tokens): flex_tokens = dropout_tags(flex_tokens) - caption = ", ".join(fixed_tokens + flex_tokens) + caption = ", ".join(fixed_tokens + flex_tokens + fixed_suffix_tokens) + + # process secondary separator + if subset.secondary_separator: + caption = caption.replace(subset.secondary_separator, subset.caption_separator) # textual inversion対応 for str_from, str_to in self.replacements.items(): @@ -777,15 +844,15 @@ def make_buckets(self): bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る) min_size and max_size are ignored when enable_bucket is False """ - print("loading image sizes.") + logger.info("loading image sizes.") for info in tqdm(self.image_data.values()): if info.image_size is None: info.image_size = self.get_image_size(info.absolute_path) if self.enable_bucket: - print("make buckets") + logger.info("make buckets") else: - print("prepare dataset") + logger.info("prepare dataset") # bucketを作成し、画像をbucketに振り分ける if self.enable_bucket: @@ -800,7 +867,7 @@ def make_buckets(self): if not self.bucket_no_upscale: self.bucket_manager.make_buckets() else: - print( + logger.warning( "min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます" ) @@ -811,7 +878,7 @@ def make_buckets(self): image_width, image_height ) - # print(image_info.image_key, image_info.bucket_reso) + # logger.info(image_info.image_key, image_info.bucket_reso) img_ar_errors.append(abs(ar_error)) self.bucket_manager.sort() @@ -829,17 +896,17 @@ def make_buckets(self): # bucket情報を表示、格納する if self.enable_bucket: self.bucket_info = {"buckets": {}} - print("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)") + logger.info("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)") for i, (reso, bucket) in enumerate(zip(self.bucket_manager.resos, self.bucket_manager.buckets)): count = len(bucket) if count > 0: self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)} - print(f"bucket {i}: resolution {reso}, count: {len(bucket)}") + logger.info(f"bucket {i}: resolution {reso}, count: {len(bucket)}") img_ar_errors = np.array(img_ar_errors) mean_img_ar_error = np.mean(np.abs(img_ar_errors)) self.bucket_info["mean_img_ar_error"] = mean_img_ar_error - print(f"mean ar error (without repeats): {mean_img_ar_error}") + logger.info(f"mean ar error (without repeats): {mean_img_ar_error}") # データ参照用indexを作る。このindexはdatasetのshuffleに用いられる self.buckets_indices: List(BucketBatchIndex) = [] @@ -859,7 +926,7 @@ def make_buckets(self): # num_of_image_types = len(set(bucket)) # bucket_batch_size = min(self.batch_size, num_of_image_types) # batch_count = int(math.ceil(len(bucket) / bucket_batch_size)) - # # print(bucket_index, num_of_image_types, bucket_batch_size, batch_count) + # # logger.info(bucket_index, num_of_image_types, bucket_batch_size, batch_count) # for batch_index in range(batch_count): # self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index)) # ↑ここまで @@ -898,7 +965,7 @@ def is_text_encoder_output_cacheable(self): def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): # マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと - print("caching latents.") + logger.info("caching latents.") image_infos = list(self.image_data.values()) @@ -908,7 +975,7 @@ def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_proc # split by resolution batches = [] batch = [] - print("checking cache validity...") + logger.info("checking cache validity...") for info in tqdm(image_infos): subset = self.image_to_subset[info.image_key] @@ -945,7 +1012,7 @@ def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_proc return # iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded - print("caching latents...") + logger.info("caching latents...") for batch in tqdm(batches, smoothing=1, total=len(batches)): cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.random_crop) @@ -959,10 +1026,10 @@ def cache_text_encoder_outputs( # latentsのキャッシュと同様に、ディスクへのキャッシュに対応する # またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと - print("caching text encoder outputs.") + logger.info("caching text encoder outputs.") image_infos = list(self.image_data.values()) - print("checking cache existence...") + logger.info("checking cache existence...") image_infos_to_cache = [] for info in tqdm(image_infos): # subset = self.image_to_subset[info.image_key] @@ -1003,7 +1070,7 @@ def cache_text_encoder_outputs( batches.append(batch) # iterate batches: call text encoder and cache outputs for memory or disk - print("caching text encoder outputs...") + logger.info("caching text encoder outputs...") for batch in tqdm(batches): infos, input_ids1, input_ids2 = zip(*batch) input_ids1 = torch.stack(input_ids1, dim=0) @@ -1106,7 +1173,9 @@ def __getitem__(self, index): for image_key in bucket[image_index : image_index + bucket_batch_size]: image_info = self.image_data[image_key] subset = self.image_to_subset[image_key] - loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0) + loss_weights.append( + self.prior_loss_weight if image_info.is_reg else 1.0 + ) # in case of fine tuning, is_reg is always False flipped = subset.flip_aug and random.random() < 0.5 # not flipped or flipped with 50% chance @@ -1272,6 +1341,8 @@ def __getitem__(self, index): example["target_sizes_hw"] = torch.stack([torch.LongTensor(x) for x in target_sizes_hw]) example["flippeds"] = flippeds + example["network_multipliers"] = torch.FloatTensor([self.network_multiplier] * len(captions)) + if self.debug_dataset: example["image_keys"] = bucket[image_index : image_index + self.batch_size] return example @@ -1346,15 +1417,16 @@ def __init__( tokenizer, max_token_length, resolution, + network_multiplier: float, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, prior_loss_weight: float, - debug_dataset, + debug_dataset: bool, ) -> None: - super().__init__(tokenizer, max_token_length, resolution, debug_dataset) + super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です" @@ -1381,7 +1453,7 @@ def __init__( self.bucket_reso_steps = None # この情報は使われない self.bucket_no_upscale = False - def read_caption(img_path, caption_extension): + def read_caption(img_path, caption_extension, enable_wildcard): # captionの候補ファイル名を作る base_name = os.path.splitext(img_path)[0] base_name_face_det = base_name @@ -1397,28 +1469,31 @@ def read_caption(img_path, caption_extension): try: lines = f.readlines() except UnicodeDecodeError as e: - print(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}") + logger.error(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}") raise e assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}" - caption = lines[0].strip() + if enable_wildcard: + caption = "\n".join([line.strip() for line in lines if line.strip() != ""]) # 空行を除く、改行で連結 + else: + caption = lines[0].strip() break return caption def load_dreambooth_dir(subset: DreamBoothSubset): if not os.path.isdir(subset.image_dir): - print(f"not directory: {subset.image_dir}") + logger.warning(f"not directory: {subset.image_dir}") return [], [] img_paths = glob_images(subset.image_dir, "*") - print(f"found directory {subset.image_dir} contains {len(img_paths)} image files") + logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う captions = [] missing_captions = [] for img_path in img_paths: - cap_for_img = read_caption(img_path, subset.caption_extension) + cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard) if cap_for_img is None and subset.class_tokens is None: - print( + logger.warning( f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}" ) captions.append("") @@ -1437,36 +1512,38 @@ def load_dreambooth_dir(subset: DreamBoothSubset): number_of_missing_captions_to_show = 5 remaining_missing_captions = number_of_missing_captions - number_of_missing_captions_to_show - print( + logger.warning( f"No caption file found for {number_of_missing_captions} images. Training will continue without captions for these images. If class token exists, it will be used. / {number_of_missing_captions}枚の画像にキャプションファイルが見つかりませんでした。これらの画像についてはキャプションなしで学習を続行します。class tokenが存在する場合はそれを使います。" ) for i, missing_caption in enumerate(missing_captions): if i >= number_of_missing_captions_to_show: - print(missing_caption + f"... and {remaining_missing_captions} more") + logger.warning(missing_caption + f"... and {remaining_missing_captions} more") break - print(missing_caption) + logger.warning(missing_caption) return img_paths, captions - print("prepare images.") + logger.info("prepare images.") num_train_images = 0 num_reg_images = 0 - reg_infos: List[ImageInfo] = [] + reg_infos: List[Tuple[ImageInfo, DreamBoothSubset]] = [] for subset in subsets: if subset.num_repeats < 1: - print( + logger.warning( f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}" ) continue if subset in self.subsets: - print( + logger.warning( f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します" ) continue img_paths, captions = load_dreambooth_dir(subset) if len(img_paths) < 1: - print(f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します") + logger.warning( + f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します" + ) continue if subset.is_reg: @@ -1477,28 +1554,28 @@ def load_dreambooth_dir(subset: DreamBoothSubset): for img_path, caption in zip(img_paths, captions): info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path) if subset.is_reg: - reg_infos.append(info) + reg_infos.append((info, subset)) else: self.register_image(info, subset) subset.img_count = len(img_paths) self.subsets.append(subset) - print(f"{num_train_images} train images with repeating.") + logger.info(f"{num_train_images} train images with repeating.") self.num_train_images = num_train_images - print(f"{num_reg_images} reg images.") + logger.info(f"{num_reg_images} reg images.") if num_train_images < num_reg_images: - print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります") + logger.warning("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります") if num_reg_images == 0: - print("no regularization images / 正則化画像が見つかりませんでした") + logger.warning("no regularization images / 正則化画像が見つかりませんでした") else: # num_repeatsを計算する:どうせ大した数ではないのでループで処理する n = 0 first_loop = True while n < num_train_images: - for info in reg_infos: + for info, subset in reg_infos: if first_loop: self.register_image(info, subset) n += info.num_repeats @@ -1520,14 +1597,15 @@ def __init__( tokenizer, max_token_length, resolution, + network_multiplier: float, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, - debug_dataset, + debug_dataset: bool, ) -> None: - super().__init__(tokenizer, max_token_length, resolution, debug_dataset) + super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) self.batch_size = batch_size @@ -1536,27 +1614,29 @@ def __init__( for subset in subsets: if subset.num_repeats < 1: - print( + logger.warning( f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}" ) continue if subset in self.subsets: - print( + logger.warning( f"ignore duplicated subset with metadata_file='{subset.metadata_file}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します" ) continue # メタデータを読み込む if os.path.exists(subset.metadata_file): - print(f"loading existing metadata: {subset.metadata_file}") + logger.info(f"loading existing metadata: {subset.metadata_file}") with open(subset.metadata_file, "rt", encoding="utf-8") as f: metadata = json.load(f) else: raise ValueError(f"no metadata / メタデータファイルがありません: {subset.metadata_file}") if len(metadata) < 1: - print(f"ignore subset with '{subset.metadata_file}': no image entries found / 画像に関するデータが見つからないためサブセットを無視します") + logger.warning( + f"ignore subset with '{subset.metadata_file}': no image entries found / 画像に関するデータが見つからないためサブセットを無視します" + ) continue tags_list = [] @@ -1587,10 +1667,24 @@ def __init__( caption = img_md.get("caption") tags = img_md.get("tags") if caption is None: - caption = tags - elif tags is not None and len(tags) > 0: - caption = caption + ", " + tags - tags_list.append(tags) + caption = tags # could be multiline + tags = None + + if subset.enable_wildcard: + # tags must be single line + if tags is not None: + tags = tags.replace("\n", subset.caption_separator) + + # add tags to each line of caption + if caption is not None and tags is not None: + caption = "\n".join( + [f"{line}{subset.caption_separator}{tags}" for line in caption.split("\n") if line.strip() != ""] + ) + else: + # use as is + if tags is not None and len(tags) > 0: + caption = caption + subset.caption_separator + tags + tags_list.append(tags) if caption is None: caption = "" @@ -1634,14 +1728,16 @@ def __init__( if not npz_any: use_npz_latents = False - print(f"npz file does not exist. ignore npz files / npzファイルが見つからないためnpzファイルを無視します") + logger.warning(f"npz file does not exist. ignore npz files / npzファイルが見つからないためnpzファイルを無視します") elif not npz_all: use_npz_latents = False - print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します") + logger.warning( + f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します" + ) if flip_aug_in_subset: - print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません") + logger.warning("maybe no flipped files / 反転されたnpzファイルがないのかもしれません") # else: - # print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません") + # logger.info("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません") # check min/max bucket size sizes = set() @@ -1657,7 +1753,9 @@ def __init__( if sizes is None: if use_npz_latents: use_npz_latents = False - print(f"npz files exist, but no bucket info in metadata. ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します") + logger.warning( + f"npz files exist, but no bucket info in metadata. ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します" + ) assert ( resolution is not None @@ -1671,8 +1769,8 @@ def __init__( self.bucket_no_upscale = bucket_no_upscale else: if not enable_bucket: - print("metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします") - print("using bucket info in metadata / メタデータ内のbucket情報を使います") + logger.info("metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします") + logger.info("using bucket info in metadata / メタデータ内のbucket情報を使います") self.enable_bucket = True assert ( @@ -1724,14 +1822,15 @@ def __init__( tokenizer, max_token_length, resolution, + network_multiplier: float, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, - debug_dataset, + debug_dataset: float, ) -> None: - super().__init__(tokenizer, max_token_length, resolution, debug_dataset) + super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) db_subsets = [] for subset in subsets: @@ -1745,6 +1844,8 @@ def __init__( subset.caption_separator, subset.keep_tokens, subset.keep_tokens_separator, + subset.secondary_separator, + subset.enable_wildcard, subset.color_aug, subset.flip_aug, subset.face_crop_aug_range, @@ -1765,6 +1866,7 @@ def __init__( tokenizer, max_token_length, resolution, + network_multiplier, enable_bucket, min_bucket_reso, max_bucket_reso, @@ -1793,7 +1895,7 @@ def __init__( assert subset is not None, "internal error: subset not found" if not os.path.isdir(subset.conditioning_data_dir): - print(f"not directory: {subset.conditioning_data_dir}") + logger.warning(f"not directory: {subset.conditioning_data_dir}") continue img_basename = os.path.basename(info.absolute_path) @@ -1851,7 +1953,9 @@ def __getitem__(self, index): assert ( cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1] ), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}" - cond_img = cv2.resize(cond_img, image_info.resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ + cond_img = cv2.resize( + cond_img, image_info.resized_size, interpolation=cv2.INTER_AREA + ) # INTER_AREAでやりたいのでcv2でリサイズ # TODO support random crop # 現在サポートしているcropはrandomではなく中央のみ @@ -1913,14 +2017,14 @@ def enable_XTI(self, *args, **kwargs): def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): for i, dataset in enumerate(self.datasets): - print(f"[Dataset {i}]") + logger.info(f"[Dataset {i}]") dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process) def cache_text_encoder_outputs( self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True ): for i, dataset in enumerate(self.datasets): - print(f"[Dataset {i}]") + logger.info(f"[Dataset {i}]") dataset.cache_text_encoder_outputs(tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process) def set_caching_mode(self, caching_mode): @@ -2004,12 +2108,15 @@ def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, fli def debug_dataset(train_dataset, show_input_ids=False): - print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") - print("`S` for next step, `E` for next epoch no. , Escape for exit. / Sキーで次のステップ、Eキーで次のエポック、Escキーで中断、終了します") + logger.info(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") + logger.info( + "`S` for next step, `E` for next epoch no. , Escape for exit. / Sキーで次のステップ、Eキーで次のエポック、Escキーで中断、終了します" + ) epoch = 1 while True: - print(f"\nepoch: {epoch}") + logger.info(f"") + logger.info(f"epoch: {epoch}") steps = (epoch - 1) * len(train_dataset) + 1 indices = list(range(len(train_dataset))) @@ -2019,11 +2126,11 @@ def debug_dataset(train_dataset, show_input_ids=False): for i, idx in enumerate(indices): train_dataset.set_current_epoch(epoch) train_dataset.set_current_step(steps) - print(f"steps: {steps} ({i + 1}/{len(train_dataset)})") + logger.info(f"steps: {steps} ({i + 1}/{len(train_dataset)})") example = train_dataset[idx] if example["latents"] is not None: - print(f"sample has latents from npz file: {example['latents'].size()}") + logger.info(f"sample has latents from npz file: {example['latents'].size()}") for j, (ik, cap, lw, iid, orgsz, crptl, trgsz, flpdz) in enumerate( zip( example["image_keys"], @@ -2036,24 +2143,26 @@ def debug_dataset(train_dataset, show_input_ids=False): example["flippeds"], ) ): - print( + logger.info( f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}", original size: {orgsz}, crop top left: {crptl}, target size: {trgsz}, flipped: {flpdz}' ) + if "network_multipliers" in example: + print(f"network multiplier: {example['network_multipliers'][j]}") if show_input_ids: - print(f"input ids: {iid}") + logger.info(f"input ids: {iid}") if "input_ids2" in example: - print(f"input ids2: {example['input_ids2'][j]}") + logger.info(f"input ids2: {example['input_ids2'][j]}") if example["images"] is not None: im = example["images"][j] - print(f"image size: {im.size()}") + logger.info(f"image size: {im.size()}") im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8) im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c im = im[:, :, ::-1] # RGB -> BGR (OpenCV) if "conditioning_images" in example: cond_img = example["conditioning_images"][j] - print(f"conditioning image size: {cond_img.size()}") + logger.info(f"conditioning image size: {cond_img.size()}") cond_img = ((cond_img.numpy() + 1.0) * 127.5).astype(np.uint8) cond_img = np.transpose(cond_img, (1, 2, 0)) cond_img = cond_img[:, :, ::-1] @@ -2105,8 +2214,8 @@ def glob_images_pathlib(dir_path, recursive): class MinimalDataset(BaseDataset): - def __init__(self, tokenizer, max_token_length, resolution, debug_dataset=False): - super().__init__(tokenizer, max_token_length, resolution, debug_dataset) + def __init__(self, tokenizer, max_token_length, resolution, network_multiplier, debug_dataset=False): + super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) self.num_train_images = 0 # update in subclass self.num_reg_images = 0 # update in subclass @@ -2201,12 +2310,12 @@ def trim_and_resize_if_required( if image_width > reso[0]: trim_size = image_width - reso[0] p = trim_size // 2 if not random_crop else random.randint(0, trim_size) - # print("w", trim_size, p) + # logger.info(f"w {trim_size} {p}") image = image[:, p : p + reso[0]] if image_height > reso[1]: trim_size = image_height - reso[1] p = trim_size // 2 if not random_crop else random.randint(0, trim_size) - # print("h", trim_size, p) + # logger.info(f"h {trim_size} {p}) image = image[p : p + reso[1]] # random cropの場合のcropされた値をどうcrop left/topに反映するべきか全くアイデアがない @@ -2266,9 +2375,8 @@ def cache_batch_latents( if flip_aug: info.latents_flipped = flipped_latent - # FIXME this slows down caching a lot, specify this as an option - if torch.cuda.is_available(): - torch.cuda.empty_cache() + if not HIGH_VRAM: + clean_memory_on_device(vae.device) def cache_batch_text_encoder_outputs( @@ -2444,7 +2552,7 @@ def get_git_revision_hash() -> str: # def replace_unet_cross_attn_to_xformers(): -# print("CrossAttention.forward has been replaced to enable xformers.") +# logger.info("CrossAttention.forward has been replaced to enable xformers.") # try: # import xformers.ops # except ImportError: @@ -2487,10 +2595,10 @@ def get_git_revision_hash() -> str: # diffusers.models.attention.CrossAttention.forward = forward_xformers def replace_unet_modules(unet: UNet2DConditionModel, mem_eff_attn, xformers, sdpa): if mem_eff_attn: - print("Enable memory efficient attention for U-Net") + logger.info("Enable memory efficient attention for U-Net") unet.set_use_memory_efficient_attention(False, True) elif xformers: - print("Enable xformers for U-Net") + logger.info("Enable xformers for U-Net") try: import xformers.ops except ImportError: @@ -2498,7 +2606,7 @@ def replace_unet_modules(unet: UNet2DConditionModel, mem_eff_attn, xformers, sdp unet.set_use_memory_efficient_attention(True, False) elif sdpa: - print("Enable SDPA for U-Net") + logger.info("Enable SDPA for U-Net") unet.set_use_sdpa(True) @@ -2509,17 +2617,17 @@ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xform replace_vae_attn_to_memory_efficient() elif xformers: # とりあえずDiffusersのxformersを使う。AttentionがあるのはMidBlockのみ - print("Use Diffusers xformers for VAE") + logger.info("Use Diffusers xformers for VAE") vae.encoder.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True) vae.decoder.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True) def replace_vae_attn_to_memory_efficient(): - print("AttentionBlock.forward has been replaced to FlashAttention (not xformers)") + logger.info("AttentionBlock.forward has been replaced to FlashAttention (not xformers)") flash_func = FlashAttentionFunction def forward_flash_attn(self, hidden_states): - print("forward_flash_attn") + logger.info("forward_flash_attn") q_bucket_size = 512 k_bucket_size = 1024 @@ -2664,7 +2772,9 @@ def get_sai_model_spec( def add_sd_models_arguments(parser: argparse.ArgumentParser): # for pretrained models - parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む") + parser.add_argument( + "--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む" + ) parser.add_argument( "--v_parameterization", action="store_true", help="enable v-parameterization training / v-parameterization学習を有効にする" ) @@ -2704,7 +2814,10 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率") parser.add_argument( - "--max_grad_norm", default=1.0, type=float, help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない" + "--max_grad_norm", + default=1.0, + type=float, + help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない", ) parser.add_argument( @@ -2751,13 +2864,23 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool): - parser.add_argument("--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ") - parser.add_argument("--output_name", type=str, default=None, help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名") parser.add_argument( - "--huggingface_repo_id", type=str, default=None, help="huggingface repo name to upload / huggingfaceにアップロードするリポジトリ名" + "--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ" ) parser.add_argument( - "--huggingface_repo_type", type=str, default=None, help="huggingface repo type to upload / huggingfaceにアップロードするリポジトリの種類" + "--output_name", type=str, default=None, help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名" + ) + parser.add_argument( + "--huggingface_repo_id", + type=str, + default=None, + help="huggingface repo name to upload / huggingfaceにアップロードするリポジトリ名", + ) + parser.add_argument( + "--huggingface_repo_type", + type=str, + default=None, + help="huggingface repo type to upload / huggingfaceにアップロードするリポジトリの種類", ) parser.add_argument( "--huggingface_path_in_repo", @@ -2793,10 +2916,16 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: help="precision in saving / 保存時に精度を変更して保存する", ) parser.add_argument( - "--save_every_n_epochs", type=int, default=None, help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する" + "--save_every_n_epochs", + type=int, + default=None, + help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する", ) parser.add_argument( - "--save_every_n_steps", type=int, default=None, help="save checkpoint every N steps / 学習中のモデルを指定ステップごとに保存する" + "--save_every_n_steps", + type=int, + default=None, + help="save checkpoint every N steps / 学習中のモデルを指定ステップごとに保存する", ) parser.add_argument( "--save_n_epoch_ratio", @@ -2831,7 +2960,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument( "--save_state", action="store_true", - help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する", + help="save training state additionally (including optimizer states etc.) when saving model / optimizerなど学習状態も含めたstateをモデル保存時に追加で保存する", + ) + parser.add_argument( + "--save_state_on_train_end", + action="store_true", + help="save training state (including optimizer states etc.) on train end / optimizerなど学習状態も含めたstateを学習完了時に保存する", ) parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate") @@ -2848,16 +2982,18 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: action="store_true", help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う", ) - parser.add_argument("--torch_compile", action="store_true", help="use torch.compile (requires PyTorch 2.0) / torch.compile を使う") parser.add_argument( - "--dynamo_backend", - type=str, - default="inductor", + "--torch_compile", action="store_true", help="use torch.compile (requires PyTorch 2.0) / torch.compile を使う" + ) + parser.add_argument( + "--dynamo_backend", + type=str, + default="inductor", # available backends: # https://github.com/huggingface/accelerate/blob/d1abd59114ada8ba673e1214218cb2878c13b82d/src/accelerate/utils/dataclasses.py#L376-L388C5 # https://pytorch.org/docs/stable/torch.compiler.html - choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"], - help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)" + choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"], + help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)", ) parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う") parser.add_argument( @@ -2866,7 +3002,10 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: help="use sdpa for CrossAttention (requires PyTorch 2.0) / CrossAttentionにsdpaを使う(PyTorch 2.0が必要)", ) parser.add_argument( - "--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ" + "--vae", + type=str, + default=None, + help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ", ) parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数") @@ -2898,12 +3037,17 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数", ) parser.add_argument( - "--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度" + "--mixed_precision", + type=str, + default="no", + choices=["no", "fp16", "bf16"], + help="use mixed precision / 混合精度を使う場合、その精度", ) parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する") parser.add_argument( "--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する" ) # TODO move to SDXL training, because it is not supported by SD1/2 + parser.add_argument("--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う") parser.add_argument( "--ddp_timeout", type=int, @@ -2939,7 +3083,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: choices=["tensorboard", "wandb", "all"], help="what logging tool(s) to use (if 'all', TensorBoard and WandB are both used) / ログ出力に使用するツール (allを指定するとTensorBoardとWandBの両方が使用される)", ) - parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列") + parser.add_argument( + "--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列" + ) parser.add_argument( "--log_tracker_name", type=str, @@ -2970,6 +3116,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: default=None, help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)", ) + parser.add_argument( + "--noise_offset_random_strength", + action="store_true", + help="use random strength between 0~noise_offset for noise offset. / noise offsetにおいて、0からnoise_offsetの間でランダムな強度を使用します。", + ) parser.add_argument( "--multires_noise_iterations", type=int, @@ -2983,6 +3134,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: help="enable input perturbation noise. used for regularization. recommended value: around 0.1 (from arxiv.org/abs/2301.11706) " + "/ input perturbation noiseを有効にする。正則化に使用される。推奨値: 0.1程度 (arxiv.org/abs/2301.11706 より)", ) + parser.add_argument( + "--ip_noise_gamma_random_strength", + action="store_true", + help="Use random strength between 0~ip_noise_gamma for input perturbation noise." + + "/ input perturbation noiseにおいて、0からip_noise_gammaの間でランダムな強度を使用します。", + ) # parser.add_argument( # "--perlin_noise", # type=int, @@ -3022,13 +3179,24 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument( "--lowram", action="store_true", - help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)", + help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込む等(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)", + ) + parser.add_argument( + "--highvram", + action="store_true", + help="disable low VRAM optimization. e.g. do not clear CUDA cache after each latent caching (for machines which have bigger VRAM) " + + "/ VRAMが少ない環境向け最適化を無効にする。たとえば各latentのキャッシュ後のCUDAキャッシュクリアを行わない等(VRAMが多い環境向け)", ) parser.add_argument( - "--sample_every_n_steps", type=int, default=None, help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する" + "--sample_every_n_steps", + type=int, + default=None, + help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する", + ) + parser.add_argument( + "--sample_at_first", action="store_true", help="generate sample images before training / 学習前にサンプル出力する" ) - parser.add_argument("--sample_at_first", action="store_true", help="generate sample images before training / 学習前にサンプル出力する") parser.add_argument( "--sample_every_n_epochs", type=int, @@ -3036,7 +3204,10 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: help="generate sample images every N epochs (overwrites n_steps) / 学習中のモデルで指定エポックごとにサンプル出力する(ステップ数指定を上書きします)", ) parser.add_argument( - "--sample_prompts", type=str, default=None, help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル" + "--sample_prompts", + type=str, + default=None, + help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル", ) parser.add_argument( "--sample_sampler", @@ -3113,14 +3284,25 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: def verify_training_args(args: argparse.Namespace): + r""" + Verify training arguments. Also reflect highvram option to global variable + 学習用引数を検証する。あわせて highvram オプションの指定をグローバル変数に反映する + """ + if args.highvram: + print("highvram is enabled / highvramが有効です") + global HIGH_VRAM + HIGH_VRAM = True + if args.v_parameterization and not args.v2: - print("v_parameterization should be with v2 not v1 or sdxl / v1やsdxlでv_parameterizationを使用することは想定されていません") + logger.warning( + "v_parameterization should be with v2 not v1 or sdxl / v1やsdxlでv_parameterizationを使用することは想定されていません" + ) if args.v2 and args.clip_skip is not None: - print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") + logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") if args.cache_latents_to_disk and not args.cache_latents: args.cache_latents = True - print( + logger.warning( "cache_latents_to_disk is enabled, so cache_latents is also enabled / cache_latents_to_diskが有効なため、cache_latentsを有効にします" ) @@ -3151,18 +3333,34 @@ def verify_training_args(args: argparse.Namespace): ) if args.zero_terminal_snr and not args.v_parameterization: - print( + logger.warning( f"zero_terminal_snr is enabled, but v_parameterization is not enabled. training will be unexpected" + " / zero_terminal_snrが有効ですが、v_parameterizationが有効ではありません。学習結果は想定外になる可能性があります" ) + if args.sample_every_n_epochs is not None and args.sample_every_n_epochs <= 0: + logger.warning( + "sample_every_n_epochs is less than or equal to 0, so it will be disabled / sample_every_n_epochsに0以下の値が指定されたため無効になります" + ) + args.sample_every_n_epochs = None + + if args.sample_every_n_steps is not None and args.sample_every_n_steps <= 0: + logger.warning( + "sample_every_n_steps is less than or equal to 0, so it will be disabled / sample_every_n_stepsに0以下の値が指定されたため無効になります" + ) + args.sample_every_n_steps = None + def add_dataset_arguments( parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool ): # dataset common - parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ") - parser.add_argument("--shuffle_caption", action="store_true", help="shuffle separated caption / 区切られたcaptionの各要素をshuffleする") + parser.add_argument( + "--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ" + ) + parser.add_argument( + "--shuffle_caption", action="store_true", help="shuffle separated caption / 区切られたcaptionの各要素をshuffleする" + ) parser.add_argument("--caption_separator", type=str, default=",", help="separator for caption / captionの区切り文字") parser.add_argument( "--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子" @@ -3186,6 +3384,18 @@ def add_dataset_arguments( help="A custom separator to divide the caption into fixed and flexible parts. Tokens before this separator will not be shuffled. If not specified, '--keep_tokens' will be used to determine the fixed number of tokens." + " / captionを固定部分と可変部分に分けるためのカスタム区切り文字。この区切り文字より前のトークンはシャッフルされない。指定しない場合、'--keep_tokens'が固定部分のトークン数として使用される。", ) + parser.add_argument( + "--secondary_separator", + type=str, + default=None, + help="a secondary separator for caption. This separator is replaced to caption_separator after dropping/shuffling caption" + + " / captionのセカンダリ区切り文字。この区切り文字はcaptionのドロップやシャッフル後にcaption_separatorに置き換えられる", + ) + parser.add_argument( + "--enable_wildcard", + action="store_true", + help="enable wildcard for caption (e.g. '{image|picture|rendition}') / captionのワイルドカードを有効にする(例:'{image|picture|rendition}')", + ) parser.add_argument( "--caption_prefix", type=str, @@ -3198,8 +3408,12 @@ def add_dataset_arguments( default=None, help="suffix for caption text / captionのテキストの末尾に付ける文字列", ) - parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする") - parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする") + parser.add_argument( + "--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする" + ) + parser.add_argument( + "--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする" + ) parser.add_argument( "--face_crop_aug_range", type=str, @@ -3212,7 +3426,9 @@ def add_dataset_arguments( help="enable random crop (for style training in face-centered crop augmentation) / ランダムな切り出しを有効にする(顔を中心としたaugmentationを行うときに画風の学習用に指定する)", ) parser.add_argument( - "--debug_dataset", action="store_true", help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)" + "--debug_dataset", + action="store_true", + help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)", ) parser.add_argument( "--resolution", @@ -3225,14 +3441,18 @@ def add_dataset_arguments( action="store_true", help="cache latents to main memory to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをメインメモリにcacheする(augmentationは使用不可) ", ) - parser.add_argument("--vae_batch_size", type=int, default=1, help="batch size for caching latents / latentのcache時のバッチサイズ") + parser.add_argument( + "--vae_batch_size", type=int, default=1, help="batch size for caching latents / latentのcache時のバッチサイズ" + ) parser.add_argument( "--cache_latents_to_disk", action="store_true", help="cache latents to disk to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをディスクにcacheする(augmentationは使用不可)", ) parser.add_argument( - "--enable_bucket", action="store_true", help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする" + "--enable_bucket", + action="store_true", + help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする", ) parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度") parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度") @@ -3243,7 +3463,9 @@ def add_dataset_arguments( help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します", ) parser.add_argument( - "--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します" + "--bucket_no_upscale", + action="store_true", + help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します", ) parser.add_argument( @@ -3287,13 +3509,20 @@ def add_dataset_arguments( if support_dreambooth: # DreamBooth dataset - parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ") + parser.add_argument( + "--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ" + ) if support_caption: # caption dataset - parser.add_argument("--in_json", type=str, default=None, help="json metadata for dataset / データセットのmetadataのjsonファイル") parser.add_argument( - "--dataset_repeats", type=int, default=1, help="repeat dataset when training with captions / キャプションでの学習時にデータセットを繰り返す回数" + "--in_json", type=str, default=None, help="json metadata for dataset / データセットのmetadataのjsonファイル" + ) + parser.add_argument( + "--dataset_repeats", + type=int, + default=1, + help="repeat dataset when training with captions / キャプションでの学習時にデータセットを繰り返す回数", ) @@ -3321,7 +3550,7 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar if args.output_config: # check if config file exists if os.path.exists(config_path): - print(f"Config file already exists. Aborting... / 出力先の設定ファイルが既に存在します: {config_path}") + logger.error(f"Config file already exists. Aborting... / 出力先の設定ファイルが既に存在します: {config_path}") exit(1) # convert args to dictionary @@ -3349,15 +3578,15 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar with open(config_path, "w") as f: toml.dump(args_dict, f) - print(f"Saved config file / 設定ファイルを保存しました: {config_path}") + logger.info(f"Saved config file / 設定ファイルを保存しました: {config_path}") exit(0) if not os.path.exists(config_path): - print(f"{config_path} not found.") + logger.info(f"{config_path} not found.") exit(1) - print(f"Loading settings from {config_path}...") - with open(config_path, "r") as f: + logger.info(f"Loading settings from {config_path}...") + with open(config_path, "r", encoding="utf-8") as f: config_dict = toml.load(f) # combine all sections into one @@ -3375,7 +3604,7 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar config_args = argparse.Namespace(**ignore_nesting_dict) args = parser.parse_args(namespace=config_args) args.config_file = os.path.splitext(args.config_file)[0] - print(args.config_file) + logger.info(args.config_file) return args @@ -3390,11 +3619,11 @@ def resume_from_local_or_hf_if_specified(accelerator, args): return if not args.resume_from_huggingface: - print(f"resume training from local state: {args.resume}") + logger.info(f"resume training from local state: {args.resume}") accelerator.load_state(args.resume) return - print(f"resume training from huggingface state: {args.resume}") + logger.info(f"resume training from huggingface state: {args.resume}") repo_id = args.resume.split("/")[0] + "/" + args.resume.split("/")[1] path_in_repo = "/".join(args.resume.split("/")[2:]) revision = None @@ -3406,7 +3635,7 @@ def resume_from_local_or_hf_if_specified(accelerator, args): repo_type = "model" else: path_in_repo, revision, repo_type = divided - print(f"Downloading state from huggingface: {repo_id}/{path_in_repo}@{revision}") + logger.info(f"Downloading state from huggingface: {repo_id}/{path_in_repo}@{revision}") list_files = huggingface_util.list_dir( repo_id=repo_id, @@ -3431,7 +3660,9 @@ def task(): loop = asyncio.get_event_loop() results = loop.run_until_complete(asyncio.gather(*[download(filename=filename.rfilename) for filename in list_files])) if len(results) == 0: - raise ValueError("No files found in the specified repo id/path/revision / 指定されたリポジトリID/パス/リビジョンにファイルが見つかりませんでした") + raise ValueError( + "No files found in the specified repo id/path/revision / 指定されたリポジトリID/パス/リビジョンにファイルが見つかりませんでした" + ) dirname = os.path.dirname(results[0]) accelerator.load_state(dirname) @@ -3478,7 +3709,7 @@ def get_optimizer(args, trainable_params): # value = tuple(value) optimizer_kwargs[key] = value - # print("optkwargs:", optimizer_kwargs) + # logger.info(f"optkwargs {optimizer}_{kwargs}") lr = args.learning_rate optimizer = None @@ -3488,7 +3719,7 @@ def get_optimizer(args, trainable_params): import lion_pytorch except ImportError: raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです") - print(f"use Lion optimizer | {optimizer_kwargs}") + logger.info(f"use Lion optimizer | {optimizer_kwargs}") optimizer_class = lion_pytorch.Lion optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) @@ -3499,14 +3730,14 @@ def get_optimizer(args, trainable_params): raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです") if optimizer_type == "AdamW8bit".lower(): - print(f"use 8-bit AdamW optimizer | {optimizer_kwargs}") + logger.info(f"use 8-bit AdamW optimizer | {optimizer_kwargs}") optimizer_class = bnb.optim.AdamW8bit optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type == "SGDNesterov8bit".lower(): - print(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}") + logger.info(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}") if "momentum" not in optimizer_kwargs: - print( + logger.warning( f"8-bit SGD with Nesterov must be with momentum, set momentum to 0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に設定します" ) optimizer_kwargs["momentum"] = 0.9 @@ -3515,7 +3746,7 @@ def get_optimizer(args, trainable_params): optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs) elif optimizer_type == "Lion8bit".lower(): - print(f"use 8-bit Lion optimizer | {optimizer_kwargs}") + logger.info(f"use 8-bit Lion optimizer | {optimizer_kwargs}") try: optimizer_class = bnb.optim.Lion8bit except AttributeError: @@ -3523,7 +3754,7 @@ def get_optimizer(args, trainable_params): "No Lion8bit. The version of bitsandbytes installed seems to be old. Please install 0.38.0 or later. / Lion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.38.0以上をインストールしてください" ) elif optimizer_type == "PagedAdamW8bit".lower(): - print(f"use 8-bit PagedAdamW optimizer | {optimizer_kwargs}") + logger.info(f"use 8-bit PagedAdamW optimizer | {optimizer_kwargs}") try: optimizer_class = bnb.optim.PagedAdamW8bit except AttributeError: @@ -3531,7 +3762,7 @@ def get_optimizer(args, trainable_params): "No PagedAdamW8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedAdamW8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" ) elif optimizer_type == "PagedLion8bit".lower(): - print(f"use 8-bit Paged Lion optimizer | {optimizer_kwargs}") + logger.info(f"use 8-bit Paged Lion optimizer | {optimizer_kwargs}") try: optimizer_class = bnb.optim.PagedLion8bit except AttributeError: @@ -3542,7 +3773,7 @@ def get_optimizer(args, trainable_params): optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type == "PagedAdamW".lower(): - print(f"use PagedAdamW optimizer | {optimizer_kwargs}") + logger.info(f"use PagedAdamW optimizer | {optimizer_kwargs}") try: import bitsandbytes as bnb except ImportError: @@ -3556,7 +3787,7 @@ def get_optimizer(args, trainable_params): optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type == "PagedAdamW32bit".lower(): - print(f"use 32-bit PagedAdamW optimizer | {optimizer_kwargs}") + logger.info(f"use 32-bit PagedAdamW optimizer | {optimizer_kwargs}") try: import bitsandbytes as bnb except ImportError: @@ -3570,16 +3801,18 @@ def get_optimizer(args, trainable_params): optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type == "SGDNesterov".lower(): - print(f"use SGD with Nesterov optimizer | {optimizer_kwargs}") + logger.info(f"use SGD with Nesterov optimizer | {optimizer_kwargs}") if "momentum" not in optimizer_kwargs: - print(f"SGD with Nesterov must be with momentum, set momentum to 0.9 / SGD with Nesterovはmomentum指定が必須のため0.9に設定します") + logger.info( + f"SGD with Nesterov must be with momentum, set momentum to 0.9 / SGD with Nesterovはmomentum指定が必須のため0.9に設定します" + ) optimizer_kwargs["momentum"] = 0.9 optimizer_class = torch.optim.SGD optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs) elif optimizer_type.startswith("DAdapt".lower()) or optimizer_type == "Prodigy".lower(): - # check lr and lr_count, and print warning + # check lr and lr_count, and logger.info warning actual_lr = lr lr_count = 1 if type(trainable_params) == list and type(trainable_params[0]) == dict: @@ -3590,12 +3823,12 @@ def get_optimizer(args, trainable_params): lr_count = len(lrs) if actual_lr <= 0.1: - print( + logger.warning( f"learning rate is too low. If using D-Adaptation or Prodigy, set learning rate around 1.0 / 学習率が低すぎるようです。D-AdaptationまたはProdigyの使用時は1.0前後の値を指定してください: lr={actual_lr}" ) - print("recommend option: lr=1.0 / 推奨は1.0です") + logger.warning("recommend option: lr=1.0 / 推奨は1.0です") if lr_count > 1: - print( + logger.warning( f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-AdaptationまたはProdigyで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}" ) @@ -3611,25 +3844,25 @@ def get_optimizer(args, trainable_params): # set optimizer if optimizer_type == "DAdaptation".lower() or optimizer_type == "DAdaptAdamPreprint".lower(): optimizer_class = experimental.DAdaptAdamPreprint - print(f"use D-Adaptation AdamPreprint optimizer | {optimizer_kwargs}") + logger.info(f"use D-Adaptation AdamPreprint optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptAdaGrad".lower(): optimizer_class = dadaptation.DAdaptAdaGrad - print(f"use D-Adaptation AdaGrad optimizer | {optimizer_kwargs}") + logger.info(f"use D-Adaptation AdaGrad optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptAdam".lower(): optimizer_class = dadaptation.DAdaptAdam - print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}") + logger.info(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptAdan".lower(): optimizer_class = dadaptation.DAdaptAdan - print(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}") + logger.info(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptAdanIP".lower(): optimizer_class = experimental.DAdaptAdanIP - print(f"use D-Adaptation AdanIP optimizer | {optimizer_kwargs}") + logger.info(f"use D-Adaptation AdanIP optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptLion".lower(): optimizer_class = dadaptation.DAdaptLion - print(f"use D-Adaptation Lion optimizer | {optimizer_kwargs}") + logger.info(f"use D-Adaptation Lion optimizer | {optimizer_kwargs}") elif optimizer_type == "DAdaptSGD".lower(): optimizer_class = dadaptation.DAdaptSGD - print(f"use D-Adaptation SGD optimizer | {optimizer_kwargs}") + logger.info(f"use D-Adaptation SGD optimizer | {optimizer_kwargs}") else: raise ValueError(f"Unknown optimizer type: {optimizer_type}") @@ -3642,7 +3875,7 @@ def get_optimizer(args, trainable_params): except ImportError: raise ImportError("No Prodigy / Prodigy がインストールされていないようです") - print(f"use Prodigy optimizer | {optimizer_kwargs}") + logger.info(f"use Prodigy optimizer | {optimizer_kwargs}") optimizer_class = prodigyopt.Prodigy optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) @@ -3651,14 +3884,16 @@ def get_optimizer(args, trainable_params): if "relative_step" not in optimizer_kwargs: optimizer_kwargs["relative_step"] = True # default if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False): - print(f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします") + logger.info( + f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします" + ) optimizer_kwargs["relative_step"] = True - print(f"use Adafactor optimizer | {optimizer_kwargs}") + logger.info(f"use Adafactor optimizer | {optimizer_kwargs}") if optimizer_kwargs["relative_step"]: - print(f"relative_step is true / relative_stepがtrueです") + logger.info(f"relative_step is true / relative_stepがtrueです") if lr != 0.0: - print(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます") + logger.warning(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます") args.learning_rate = None # trainable_paramsがgroupだった時の処理:lrを削除する @@ -3670,37 +3905,37 @@ def get_optimizer(args, trainable_params): if has_group_lr: # 一応argsを無効にしておく TODO 依存関係が逆転してるのであまり望ましくない - print(f"unet_lr and text_encoder_lr are ignored / unet_lrとtext_encoder_lrは無視されます") + logger.warning(f"unet_lr and text_encoder_lr are ignored / unet_lrとtext_encoder_lrは無視されます") args.unet_lr = None args.text_encoder_lr = None if args.lr_scheduler != "adafactor": - print(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します") + logger.info(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します") args.lr_scheduler = f"adafactor:{lr}" # ちょっと微妙だけど lr = None else: if args.max_grad_norm != 0.0: - print( + logger.warning( f"because max_grad_norm is set, clip_grad_norm is enabled. consider set to 0 / max_grad_normが設定されているためclip_grad_normが有効になります。0に設定して無効にしたほうがいいかもしれません" ) if args.lr_scheduler != "constant_with_warmup": - print(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません") + logger.warning(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません") if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0: - print(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません") + logger.warning(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません") optimizer_class = transformers.optimization.Adafactor optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type == "AdamW".lower(): - print(f"use AdamW optimizer | {optimizer_kwargs}") + logger.info(f"use AdamW optimizer | {optimizer_kwargs}") optimizer_class = torch.optim.AdamW optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) if optimizer is None: # 任意のoptimizerを使う optimizer_type = args.optimizer_type # lowerでないやつ(微妙) - print(f"use {optimizer_type} | {optimizer_kwargs}") + logger.info(f"use {optimizer_type} | {optimizer_kwargs}") if "." not in optimizer_type: optimizer_module = torch.optim else: @@ -3746,7 +3981,7 @@ def wrap_check_needless_num_warmup_steps(return_vals): # using any lr_scheduler from other library if args.lr_scheduler_type: lr_scheduler_type = args.lr_scheduler_type - print(f"use {lr_scheduler_type} | {lr_scheduler_kwargs} as lr_scheduler") + logger.info(f"use {lr_scheduler_type} | {lr_scheduler_kwargs} as lr_scheduler") if "." not in lr_scheduler_type: # default to use torch.optim lr_scheduler_module = torch.optim.lr_scheduler else: @@ -3762,7 +3997,7 @@ def wrap_check_needless_num_warmup_steps(return_vals): type(optimizer) == transformers.optimization.Adafactor ), f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください" initial_lr = float(name.split(":")[1]) - # print("adafactor scheduler init lr", initial_lr) + # logger.info(f"adafactor scheduler init lr {initial_lr}") return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr)) name = SchedulerType(name) @@ -3827,20 +4062,20 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool): if support_metadata: if args.in_json is not None and (args.color_aug or args.random_crop): - print( + logger.warning( f"latents in npz is ignored when color_aug or random_crop is True / color_augまたはrandom_cropを有効にした場合、npzファイルのlatentsは無視されます" ) def load_tokenizer(args: argparse.Namespace): - print("prepare tokenizer") + logger.info("prepare tokenizer") original_path = V2_STABLE_DIFFUSION_PATH if args.v2 else TOKENIZER_PATH tokenizer: CLIPTokenizer = None if args.tokenizer_cache_dir: local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_")) if os.path.exists(local_tokenizer_path): - print(f"load tokenizer from cache: {local_tokenizer_path}") + logger.info(f"load tokenizer from cache: {local_tokenizer_path}") tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) # same for v1 and v2 if tokenizer is None: @@ -3850,10 +4085,10 @@ def load_tokenizer(args: argparse.Namespace): tokenizer = CLIPTokenizer.from_pretrained(original_path) if hasattr(args, "max_token_length") and args.max_token_length is not None: - print(f"update token length: {args.max_token_length}") + logger.info(f"update token length: {args.max_token_length}") if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path): - print(f"save Tokenizer to cache: {local_tokenizer_path}") + logger.info(f"save Tokenizer to cache: {local_tokenizer_path}") tokenizer.save_pretrained(local_tokenizer_path) return tokenizer @@ -3875,7 +4110,9 @@ def prepare_accelerator(args: argparse.Namespace): log_with = args.log_with if log_with in ["tensorboard", "all"]: if logging_dir is None: - raise ValueError("logging_dir is required when log_with is tensorboard / Tensorboardを使う場合、logging_dirを指定してください") + raise ValueError( + "logging_dir is required when log_with is tensorboard / Tensorboardを使う場合、logging_dirを指定してください" + ) if log_with in ["wandb", "all"]: try: import wandb @@ -3886,7 +4123,7 @@ def prepare_accelerator(args: argparse.Namespace): os.environ["WANDB_DIR"] = logging_dir if args.wandb_api_key is not None: wandb.login(key=args.wandb_api_key) - + # torch.compile のオプション。 NO の場合は torch.compile は使わない dynamo_backend = "NO" if args.torch_compile: @@ -3894,9 +4131,13 @@ def prepare_accelerator(args: argparse.Namespace): kwargs_handlers = ( InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout)) if args.ddp_timeout else None, - DistributedDataParallelKwargs(gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph) - if args.ddp_gradient_as_bucket_view or args.ddp_static_graph - else None, + ( + DistributedDataParallelKwargs( + gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph + ) + if args.ddp_gradient_as_bucket_view or args.ddp_static_graph + else None + ), ) kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers)) accelerator = Accelerator( @@ -3907,6 +4148,7 @@ def prepare_accelerator(args: argparse.Namespace): kwargs_handlers=kwargs_handlers, dynamo_backend=dynamo_backend, ) + print("accelerator device:", accelerator.device) return accelerator @@ -3933,17 +4175,17 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une name_or_path = os.path.realpath(name_or_path) if os.path.islink(name_or_path) else name_or_path load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers if load_stable_diffusion_format: - print(f"load StableDiffusion checkpoint: {name_or_path}") + logger.info(f"load StableDiffusion checkpoint: {name_or_path}") text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint( args.v2, name_or_path, device, unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2 ) else: # Diffusers model is loaded to CPU - print(f"load Diffusers pretrained models: {name_or_path}") + logger.info(f"load Diffusers pretrained models: {name_or_path}") try: pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None) except EnvironmentError as ex: - print( + logger.error( f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}" ) raise ex @@ -3954,7 +4196,7 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une # Diffusers U-Net to original U-Net # TODO *.ckpt/*.safetensorsのv2と同じ形式にここで変換すると良さそう - # print(f"unet config: {unet.config}") + # logger.info(f"unet config: {unet.config}") original_unet = UNet2DConditionModel( unet.config.sample_size, unet.config.attention_head_dim, @@ -3964,12 +4206,12 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une ) original_unet.load_state_dict(unet.state_dict()) unet = original_unet - print("U-Net converted to original U-Net") + logger.info("U-Net converted to original U-Net") # VAEを読み込む if args.vae is not None: vae = model_util.load_vae(args.vae, weight_dtype) - print("additional VAE loaded") + logger.info("additional VAE loaded") return text_encoder, vae, unet, load_stable_diffusion_format @@ -3978,7 +4220,7 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio # load models for each process for pi in range(accelerator.state.num_processes): if pi == accelerator.state.local_process_index: - print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") + logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") text_encoder, vae, unet, load_stable_diffusion_format = _load_target_model( args, @@ -3993,8 +4235,7 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio unet.to(accelerator.device) vae.to(accelerator.device) - gc.collect() - torch.cuda.empty_cache() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() return text_encoder, vae, unet, load_stable_diffusion_format @@ -4045,7 +4286,9 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod # v1: ... の三連を ... へ戻す states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # for i in range(1, args.max_token_length, tokenizer.model_max_length): - states_list.append(encoder_hidden_states[:, i : i + tokenizer.model_max_length - 2]) # の後から の前まで + states_list.append( + encoder_hidden_states[:, i : i + tokenizer.model_max_length - 2] + ) # の後から の前まで states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # encoder_hidden_states = torch.cat(states_list, dim=1) @@ -4287,7 +4530,8 @@ def save_sd_model_on_epoch_end_or_stepwise_common( ckpt_name = get_step_ckpt_name(args, ext, global_step) ckpt_file = os.path.join(args.output_dir, ckpt_name) - print(f"\nsaving checkpoint: {ckpt_file}") + logger.info("") + logger.info(f"saving checkpoint: {ckpt_file}") sd_saver(ckpt_file, epoch_no, global_step) if args.huggingface_repo_id is not None: @@ -4302,7 +4546,7 @@ def save_sd_model_on_epoch_end_or_stepwise_common( remove_ckpt_file = os.path.join(args.output_dir, remove_ckpt_name) if os.path.exists(remove_ckpt_file): - print(f"removing old checkpoint: {remove_ckpt_file}") + logger.info(f"removing old checkpoint: {remove_ckpt_file}") os.remove(remove_ckpt_file) else: @@ -4311,7 +4555,8 @@ def save_sd_model_on_epoch_end_or_stepwise_common( else: out_dir = os.path.join(args.output_dir, STEP_DIFFUSERS_DIR_NAME.format(model_name, global_step)) - print(f"\nsaving model: {out_dir}") + logger.info("") + logger.info(f"saving model: {out_dir}") diffusers_saver(out_dir) if args.huggingface_repo_id is not None: @@ -4325,7 +4570,7 @@ def save_sd_model_on_epoch_end_or_stepwise_common( remove_out_dir = os.path.join(args.output_dir, STEP_DIFFUSERS_DIR_NAME.format(model_name, remove_no)) if os.path.exists(remove_out_dir): - print(f"removing old model: {remove_out_dir}") + logger.info(f"removing old model: {remove_out_dir}") shutil.rmtree(remove_out_dir) if args.save_state: @@ -4338,13 +4583,14 @@ def save_sd_model_on_epoch_end_or_stepwise_common( def save_and_remove_state_on_epoch_end(args: argparse.Namespace, accelerator, epoch_no): model_name = default_if_none(args.output_name, DEFAULT_EPOCH_NAME) - print(f"\nsaving state at epoch {epoch_no}") + logger.info("") + logger.info(f"saving state at epoch {epoch_no}") os.makedirs(args.output_dir, exist_ok=True) state_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no)) accelerator.save_state(state_dir) if args.save_state_to_huggingface: - print("uploading state to huggingface.") + logger.info("uploading state to huggingface.") huggingface_util.upload(args, state_dir, "/" + EPOCH_STATE_NAME.format(model_name, epoch_no)) last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs @@ -4352,20 +4598,21 @@ def save_and_remove_state_on_epoch_end(args: argparse.Namespace, accelerator, ep remove_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epochs state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no)) if os.path.exists(state_dir_old): - print(f"removing old state: {state_dir_old}") + logger.info(f"removing old state: {state_dir_old}") shutil.rmtree(state_dir_old) def save_and_remove_state_stepwise(args: argparse.Namespace, accelerator, step_no): model_name = default_if_none(args.output_name, DEFAULT_STEP_NAME) - print(f"\nsaving state at step {step_no}") + logger.info("") + logger.info(f"saving state at step {step_no}") os.makedirs(args.output_dir, exist_ok=True) state_dir = os.path.join(args.output_dir, STEP_STATE_NAME.format(model_name, step_no)) accelerator.save_state(state_dir) if args.save_state_to_huggingface: - print("uploading state to huggingface.") + logger.info("uploading state to huggingface.") huggingface_util.upload(args, state_dir, "/" + STEP_STATE_NAME.format(model_name, step_no)) last_n_steps = args.save_last_n_steps_state if args.save_last_n_steps_state else args.save_last_n_steps @@ -4377,21 +4624,22 @@ def save_and_remove_state_stepwise(args: argparse.Namespace, accelerator, step_n if remove_step_no > 0: state_dir_old = os.path.join(args.output_dir, STEP_STATE_NAME.format(model_name, remove_step_no)) if os.path.exists(state_dir_old): - print(f"removing old state: {state_dir_old}") + logger.info(f"removing old state: {state_dir_old}") shutil.rmtree(state_dir_old) def save_state_on_train_end(args: argparse.Namespace, accelerator): model_name = default_if_none(args.output_name, DEFAULT_LAST_OUTPUT_NAME) - print("\nsaving last state.") + logger.info("") + logger.info("saving last state.") os.makedirs(args.output_dir, exist_ok=True) state_dir = os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)) accelerator.save_state(state_dir) if args.save_state_to_huggingface: - print("uploading last state to huggingface.") + logger.info("uploading last state to huggingface.") huggingface_util.upload(args, state_dir, "/" + LAST_STATE_NAME.format(model_name)) @@ -4440,7 +4688,7 @@ def save_sd_model_on_train_end_common( ckpt_name = model_name + (".safetensors" if use_safetensors else ".ckpt") ckpt_file = os.path.join(args.output_dir, ckpt_name) - print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}") + logger.info(f"save trained model as StableDiffusion checkpoint to {ckpt_file}") sd_saver(ckpt_file, epoch, global_step) if args.huggingface_repo_id is not None: @@ -4449,7 +4697,7 @@ def save_sd_model_on_train_end_common( out_dir = os.path.join(args.output_dir, model_name) os.makedirs(out_dir, exist_ok=True) - print(f"save trained model as Diffusers to {out_dir}") + logger.info(f"save trained model as Diffusers to {out_dir}") diffusers_saver(out_dir) if args.huggingface_repo_id is not None: @@ -4460,7 +4708,11 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) if args.noise_offset: - noise = custom_train_functions.apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale) + if args.noise_offset_random_strength: + noise_offset = torch.rand(1, device=latents.device) * args.noise_offset + else: + noise_offset = args.noise_offset + noise = custom_train_functions.apply_noise_offset(latents, noise, noise_offset, args.adaptive_noise_scale) if args.multires_noise_iterations: noise = custom_train_functions.pyramid_noise_like( noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount @@ -4477,7 +4729,11 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) if args.ip_noise_gamma: - noisy_latents = noise_scheduler.add_noise(latents, noise + args.ip_noise_gamma * torch.randn_like(latents), timesteps) + if args.ip_noise_gamma_random_strength: + strength = torch.rand(1, device=latents.device) * args.ip_noise_gamma + else: + strength = args.ip_noise_gamma + noisy_latents = noise_scheduler.add_noise(latents, noise + strength * torch.randn_like(latents), timesteps) else: noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) @@ -4559,7 +4815,7 @@ def get_my_scheduler( # clip_sample=Trueにする if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: - # print("set clip_sample to True") + # logger.info("set clip_sample to True") scheduler.config.clip_sample = True return scheduler @@ -4618,8 +4874,8 @@ def line_to_prompt_dict(line: str) -> dict: continue except ValueError as ex: - print(f"Exception in parsing / 解析エラー: {parg}") - print(ex) + logger.error(f"Exception in parsing / 解析エラー: {parg}") + logger.error(ex) return prompt_dict @@ -4641,6 +4897,7 @@ def sample_images_common( """ StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した """ + if steps == 0: if not args.sample_at_first: return @@ -4655,13 +4912,16 @@ def sample_images_common( if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch return - print(f"\ngenerating sample images at step / サンプル画像生成 ステップ: {steps}") + logger.info("") + logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}") if not os.path.isfile(args.sample_prompts): - print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") + logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") return + distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here + org_vae_device = vae.device # CPUにいるはず - vae.to(device) + vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device # unwrap unet and text_encoder(s) unet = accelerator.unwrap_model(unet) @@ -4671,10 +4931,6 @@ def sample_images_common( text_encoder = accelerator.unwrap_model(text_encoder) # read prompts - - # with open(args.sample_prompts, "rt", encoding="utf-8") as f: - # prompts = f.readlines() - if args.sample_prompts.endswith(".txt"): with open(args.sample_prompts, "r", encoding="utf-8") as f: lines = f.readlines() @@ -4687,12 +4943,11 @@ def sample_images_common( with open(args.sample_prompts, "r", encoding="utf-8") as f: prompts = json.load(f) - schedulers: dict = {} + # schedulers: dict = {} cannot find where this is used default_scheduler = get_my_scheduler( sample_sampler=args.sample_sampler, v_parameterization=args.v_parameterization, ) - schedulers[args.sample_sampler] = default_scheduler pipeline = pipe_class( text_encoder=text_encoder, @@ -4705,105 +4960,58 @@ def sample_images_common( requires_safety_checker=False, clip_skip=args.clip_skip, ) - pipeline.to(device) - + pipeline.to(distributed_state.device) save_dir = args.output_dir + "/sample" os.makedirs(save_dir, exist_ok=True) - rng_state = torch.get_rng_state() - cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None - - with torch.no_grad(): - # with accelerator.autocast(): - for i, prompt_dict in enumerate(prompts): - if not accelerator.is_main_process: - continue - - if isinstance(prompt_dict, str): - prompt_dict = line_to_prompt_dict(prompt_dict) - - assert isinstance(prompt_dict, dict) - negative_prompt = prompt_dict.get("negative_prompt") - sample_steps = prompt_dict.get("sample_steps", 30) - width = prompt_dict.get("width", 512) - height = prompt_dict.get("height", 512) - scale = prompt_dict.get("scale", 7.5) - seed = prompt_dict.get("seed") - controlnet_image = prompt_dict.get("controlnet_image") - prompt: str = prompt_dict.get("prompt", "") - sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) - - if seed is not None: - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - - scheduler = schedulers.get(sampler_name) - if scheduler is None: - scheduler = get_my_scheduler( - sample_sampler=sampler_name, - v_parameterization=args.v_parameterization, - ) - schedulers[sampler_name] = scheduler - pipeline.scheduler = scheduler - - if prompt_replacement is not None: - prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) - if negative_prompt is not None: - negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) - - if controlnet_image is not None: - controlnet_image = Image.open(controlnet_image).convert("RGB") - controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS) - - height = max(64, height - height % 8) # round to divisible by 8 - width = max(64, width - width % 8) # round to divisible by 8 - print(f"prompt: {prompt}") - print(f"negative_prompt: {negative_prompt}") - print(f"height: {height}") - print(f"width: {width}") - print(f"sample_steps: {sample_steps}") - print(f"scale: {scale}") - print(f"sample_sampler: {sampler_name}") - if seed is not None: - print(f"seed: {seed}") - with accelerator.autocast(): - latents = pipeline( - prompt=prompt, - height=height, - width=width, - num_inference_steps=sample_steps, - guidance_scale=scale, - negative_prompt=negative_prompt, - controlnet=controlnet, - controlnet_image=controlnet_image, - ) + # preprocess prompts + for i in range(len(prompts)): + prompt_dict = prompts[i] + if isinstance(prompt_dict, str): + prompt_dict = line_to_prompt_dict(prompt_dict) + prompts[i] = prompt_dict + assert isinstance(prompt_dict, dict) - image = pipeline.latents_to_image(latents)[0] + # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict. + prompt_dict["enum"] = i + prompt_dict.pop("subset", None) - ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) - num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" - seed_suffix = "" if seed is None else f"_{seed}" - img_filename = ( - f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png" - ) + # save random state to restore later + rng_state = torch.get_rng_state() + cuda_rng_state = None + try: + cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None + except Exception: + pass - image.save(os.path.join(save_dir, img_filename)) + if distributed_state.num_processes <= 1: + # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. + with torch.no_grad(): + for prompt_dict in prompts: + sample_image_inference( + accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet + ) + else: + # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) + # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. + per_process_prompts = [] # list of lists + for i in range(distributed_state.num_processes): + per_process_prompts.append(prompts[i :: distributed_state.num_processes]) - # wandb有効時のみログを送信 - try: - wandb_tracker = accelerator.get_tracker("wandb") - try: - import wandb - except ImportError: # 事前に一度確認するのでここはエラー出ないはず - raise ImportError("No wandb / wandb がインストールされていないようです") - - wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) - except: # wandb 無効時 - pass + with torch.no_grad(): + with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists: + for prompt_dict in prompt_dict_lists[0]: + sample_image_inference( + accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet + ) # clear pipeline and cache to reduce vram usage del pipeline - torch.cuda.empty_cache() + + # I'm not sure which of these is the correct way to clear the memory, but accelerator's device is used in the pipeline, so I'm using it here. + # with torch.cuda.device(torch.cuda.current_device()): + # torch.cuda.empty_cache() + clean_memory_on_device(accelerator.device) torch.set_rng_state(rng_state) if cuda_rng_state is not None: @@ -4811,8 +5019,105 @@ def sample_images_common( vae.to(org_vae_device) +def sample_image_inference( + accelerator: Accelerator, + args: argparse.Namespace, + pipeline, + save_dir, + prompt_dict, + epoch, + steps, + prompt_replacement, + controlnet=None, +): + assert isinstance(prompt_dict, dict) + negative_prompt = prompt_dict.get("negative_prompt") + sample_steps = prompt_dict.get("sample_steps", 30) + width = prompt_dict.get("width", 512) + height = prompt_dict.get("height", 512) + scale = prompt_dict.get("scale", 7.5) + seed = prompt_dict.get("seed") + controlnet_image = prompt_dict.get("controlnet_image") + prompt: str = prompt_dict.get("prompt", "") + sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) + + if prompt_replacement is not None: + prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) + if negative_prompt is not None: + negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + + if seed is not None: + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + else: + # True random sample image generation + torch.seed() + torch.cuda.seed() + + scheduler = get_my_scheduler( + sample_sampler=sampler_name, + v_parameterization=args.v_parameterization, + ) + pipeline.scheduler = scheduler + + if controlnet_image is not None: + controlnet_image = Image.open(controlnet_image).convert("RGB") + controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS) + + height = max(64, height - height % 8) # round to divisible by 8 + width = max(64, width - width % 8) # round to divisible by 8 + logger.info(f"prompt: {prompt}") + logger.info(f"negative_prompt: {negative_prompt}") + logger.info(f"height: {height}") + logger.info(f"width: {width}") + logger.info(f"sample_steps: {sample_steps}") + logger.info(f"scale: {scale}") + logger.info(f"sample_sampler: {sampler_name}") + if seed is not None: + logger.info(f"seed: {seed}") + with accelerator.autocast(): + latents = pipeline( + prompt=prompt, + height=height, + width=width, + num_inference_steps=sample_steps, + guidance_scale=scale, + negative_prompt=negative_prompt, + controlnet=controlnet, + controlnet_image=controlnet_image, + ) + + with torch.cuda.device(torch.cuda.current_device()): + torch.cuda.empty_cache() + + image = pipeline.latents_to_image(latents)[0] + + # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list + # but adding 'enum' to the filename should be enough + + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" + seed_suffix = "" if seed is None else f"_{seed}" + i: int = prompt_dict["enum"] + img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" + image.save(os.path.join(save_dir, img_filename)) + + # wandb有効時のみログを送信 + try: + wandb_tracker = accelerator.get_tracker("wandb") + try: + import wandb + except ImportError: # 事前に一度確認するのでここはエラー出ないはず + raise ImportError("No wandb / wandb がインストールされていないようです") + + wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) + except: # wandb 無効時 + pass + + # endregion + # region 前処理用 @@ -4831,7 +5136,7 @@ def __getitem__(self, idx): # convert to tensor temporarily so dataloader will accept it tensor_pil = transforms.functional.pil_to_tensor(image) except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") + logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") return None return (tensor_pil, img_path) diff --git a/library/utils.py b/library/utils.py index 7d801a676..3037c055d 100644 --- a/library/utils.py +++ b/library/utils.py @@ -1,6 +1,266 @@ +import logging +import sys import threading +import torch +from torchvision import transforms from typing import * +from diffusers import EulerAncestralDiscreteScheduler +import diffusers.schedulers.scheduling_euler_ancestral_discrete +from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput def fire_in_thread(f, *args, **kwargs): - threading.Thread(target=f, args=args, kwargs=kwargs).start() \ No newline at end of file + threading.Thread(target=f, args=args, kwargs=kwargs).start() + + +def add_logging_arguments(parser): + parser.add_argument( + "--console_log_level", + type=str, + default=None, + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + help="Set the logging level, default is INFO / ログレベルを設定する。デフォルトはINFO", + ) + parser.add_argument( + "--console_log_file", + type=str, + default=None, + help="Log to a file instead of stderr / 標準エラー出力ではなくファイルにログを出力する", + ) + parser.add_argument("--console_log_simple", action="store_true", help="Simple log output / シンプルなログ出力") + + +def setup_logging(args=None, log_level=None, reset=False): + if logging.root.handlers: + if reset: + # remove all handlers + for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + else: + return + + # log_level can be set by the caller or by the args, the caller has priority. If not set, use INFO + if log_level is None and args is not None: + log_level = args.console_log_level + if log_level is None: + log_level = "INFO" + log_level = getattr(logging, log_level) + + msg_init = None + if args is not None and args.console_log_file: + handler = logging.FileHandler(args.console_log_file, mode="w") + else: + handler = None + if not args or not args.console_log_simple: + try: + from rich.logging import RichHandler + from rich.console import Console + from rich.logging import RichHandler + + handler = RichHandler(console=Console(stderr=True)) + except ImportError: + # print("rich is not installed, using basic logging") + msg_init = "rich is not installed, using basic logging" + + if handler is None: + handler = logging.StreamHandler(sys.stdout) # same as print + handler.propagate = False + + formatter = logging.Formatter( + fmt="%(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + handler.setFormatter(formatter) + logging.root.setLevel(log_level) + logging.root.addHandler(handler) + + if msg_init is not None: + logger = logging.getLogger(__name__) + logger.info(msg_init) + + + +# TODO make inf_utils.py + + +# region Gradual Latent hires fix + + +class GradualLatent: + def __init__( + self, + ratio, + start_timesteps, + every_n_steps, + ratio_step, + s_noise=1.0, + gaussian_blur_ksize=None, + gaussian_blur_sigma=0.5, + gaussian_blur_strength=0.5, + unsharp_target_x=True, + ): + self.ratio = ratio + self.start_timesteps = start_timesteps + self.every_n_steps = every_n_steps + self.ratio_step = ratio_step + self.s_noise = s_noise + self.gaussian_blur_ksize = gaussian_blur_ksize + self.gaussian_blur_sigma = gaussian_blur_sigma + self.gaussian_blur_strength = gaussian_blur_strength + self.unsharp_target_x = unsharp_target_x + + def __str__(self) -> str: + return ( + f"GradualLatent(ratio={self.ratio}, start_timesteps={self.start_timesteps}, " + + f"every_n_steps={self.every_n_steps}, ratio_step={self.ratio_step}, s_noise={self.s_noise}, " + + f"gaussian_blur_ksize={self.gaussian_blur_ksize}, gaussian_blur_sigma={self.gaussian_blur_sigma}, gaussian_blur_strength={self.gaussian_blur_strength}, " + + f"unsharp_target_x={self.unsharp_target_x})" + ) + + def apply_unshark_mask(self, x: torch.Tensor): + if self.gaussian_blur_ksize is None: + return x + blurred = transforms.functional.gaussian_blur(x, self.gaussian_blur_ksize, self.gaussian_blur_sigma) + # mask = torch.sigmoid((x - blurred) * self.gaussian_blur_strength) + mask = (x - blurred) * self.gaussian_blur_strength + sharpened = x + mask + return sharpened + + def interpolate(self, x: torch.Tensor, resized_size, unsharp=True): + org_dtype = x.dtype + if org_dtype == torch.bfloat16: + x = x.float() + + x = torch.nn.functional.interpolate(x, size=resized_size, mode="bicubic", align_corners=False).to(dtype=org_dtype) + + # apply unsharp mask / アンシャープマスクを適用する + if unsharp and self.gaussian_blur_ksize: + x = self.apply_unshark_mask(x) + + return x + + +class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.resized_size = None + self.gradual_latent = None + + def set_gradual_latent_params(self, size, gradual_latent: GradualLatent): + self.resized_size = size + self.gradual_latent = gradual_latent + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a + [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple. + + Returns: + [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, + [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned, + otherwise a tuple is returned where the first element is the sample tensor. + + """ + + if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if not self.is_scale_input_called: + # logger.warning( + print( + "The `scale_model_input` function should be called before `step` to ensure correct denoising. " + "See `StableDiffusionPipeline` for a usage example." + ) + + if self.step_index is None: + self._init_step_index(timestep) + + sigma = self.sigmas[self.step_index] + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + if self.config.prediction_type == "epsilon": + pred_original_sample = sample - sigma * model_output + elif self.config.prediction_type == "v_prediction": + # * c_out + input * c_skip + pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) + elif self.config.prediction_type == "sample": + raise NotImplementedError("prediction_type not implemented yet: sample") + else: + raise ValueError(f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`") + + sigma_from = self.sigmas[self.step_index] + sigma_to = self.sigmas[self.step_index + 1] + sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5 + sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 + + # 2. Convert to an ODE derivative + derivative = (sample - pred_original_sample) / sigma + + dt = sigma_down - sigma + + device = model_output.device + if self.resized_size is None: + prev_sample = sample + derivative * dt + + noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor( + model_output.shape, dtype=model_output.dtype, device=device, generator=generator + ) + s_noise = 1.0 + else: + print("resized_size", self.resized_size, "model_output.shape", model_output.shape, "sample.shape", sample.shape) + s_noise = self.gradual_latent.s_noise + + if self.gradual_latent.unsharp_target_x: + prev_sample = sample + derivative * dt + prev_sample = self.gradual_latent.interpolate(prev_sample, self.resized_size) + else: + sample = self.gradual_latent.interpolate(sample, self.resized_size) + derivative = self.gradual_latent.interpolate(derivative, self.resized_size, unsharp=False) + prev_sample = sample + derivative * dt + + noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor( + (model_output.shape[0], model_output.shape[1], self.resized_size[0], self.resized_size[1]), + dtype=model_output.dtype, + device=device, + generator=generator, + ) + + prev_sample = prev_sample + noise * sigma_up * s_noise + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return EulerAncestralDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + + +# endregion diff --git a/networks/check_lora_weights.py b/networks/check_lora_weights.py index 51f581b29..794659c94 100644 --- a/networks/check_lora_weights.py +++ b/networks/check_lora_weights.py @@ -2,10 +2,13 @@ import os import torch from safetensors.torch import load_file - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def main(file): - print(f"loading: {file}") + logger.info(f"loading: {file}") if os.path.splitext(file)[1] == ".safetensors": sd = load_file(file) else: diff --git a/networks/control_net_lllite.py b/networks/control_net_lllite.py index 4ebfef7a4..c9377bee8 100644 --- a/networks/control_net_lllite.py +++ b/networks/control_net_lllite.py @@ -2,7 +2,10 @@ from typing import Optional, List, Type import torch from library import sdxl_original_unet - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) # input_blocksに適用するかどうか / if True, input_blocks are not applied SKIP_INPUT_BLOCKS = False @@ -125,7 +128,7 @@ def set_cond_image(self, cond_image): return # timestepごとに呼ばれないので、あらかじめ計算しておく / it is not called for each timestep, so calculate it in advance - # print(f"C {self.lllite_name}, cond_image.shape={cond_image.shape}") + # logger.info(f"C {self.lllite_name}, cond_image.shape={cond_image.shape}") cx = self.conditioning1(cond_image) if not self.is_conv2d: # reshape / b,c,h,w -> b,h*w,c @@ -155,7 +158,7 @@ def forward(self, x): cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1) if self.use_zeros_for_batch_uncond: cx[0::2] = 0.0 # uncond is zero - # print(f"C {self.lllite_name}, x.shape={x.shape}, cx.shape={cx.shape}") + # logger.info(f"C {self.lllite_name}, x.shape={x.shape}, cx.shape={cx.shape}") # downで入力の次元数を削減し、conditioning image embeddingと結合する # 加算ではなくchannel方向に結合することで、うまいこと混ぜてくれることを期待している @@ -286,7 +289,7 @@ def create_modules( # create module instances self.unet_modules: List[LLLiteModule] = create_modules(unet, target_modules, LLLiteModule) - print(f"create ControlNet LLLite for U-Net: {len(self.unet_modules)} modules.") + logger.info(f"create ControlNet LLLite for U-Net: {len(self.unet_modules)} modules.") def forward(self, x): return x # dummy @@ -319,7 +322,7 @@ def load_weights(self, file): return info def apply_to(self): - print("applying LLLite for U-Net...") + logger.info("applying LLLite for U-Net...") for module in self.unet_modules: module.apply_to() self.add_module(module.lllite_name, module) @@ -374,19 +377,19 @@ def save_weights(self, file, dtype, metadata): # sdxl_original_unet.USE_REENTRANT = False # test shape etc - print("create unet") + logger.info("create unet") unet = sdxl_original_unet.SdxlUNet2DConditionModel() unet.to("cuda").to(torch.float16) - print("create ControlNet-LLLite") + logger.info("create ControlNet-LLLite") control_net = ControlNetLLLite(unet, 32, 64) control_net.apply_to() control_net.to("cuda") - print(control_net) + logger.info(control_net) - # print number of parameters - print("number of parameters", sum(p.numel() for p in control_net.parameters() if p.requires_grad)) + # logger.info number of parameters + logger.info(f"number of parameters {sum(p.numel() for p in control_net.parameters() if p.requires_grad)}") input() @@ -398,12 +401,12 @@ def save_weights(self, file, dtype, metadata): # # visualize # import torchviz - # print("run visualize") + # logger.info("run visualize") # controlnet.set_control(conditioning_image) # output = unet(x, t, ctx, y) - # print("make_dot") + # logger.info("make_dot") # image = torchviz.make_dot(output, params=dict(controlnet.named_parameters())) - # print("render") + # logger.info("render") # image.format = "svg" # "png" # image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time # input() @@ -414,12 +417,12 @@ def save_weights(self, file, dtype, metadata): scaler = torch.cuda.amp.GradScaler(enabled=True) - print("start training") + logger.info("start training") steps = 10 sample_param = [p for p in control_net.named_parameters() if "up" in p[0]][0] for step in range(steps): - print(f"step {step}") + logger.info(f"step {step}") batch_size = 1 conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0 @@ -439,7 +442,7 @@ def save_weights(self, file, dtype, metadata): scaler.step(optimizer) scaler.update() optimizer.zero_grad(set_to_none=True) - print(sample_param) + logger.info(f"{sample_param}") # from safetensors.torch import save_file diff --git a/networks/control_net_lllite_for_train.py b/networks/control_net_lllite_for_train.py index 026880015..65b3520cf 100644 --- a/networks/control_net_lllite_for_train.py +++ b/networks/control_net_lllite_for_train.py @@ -6,7 +6,10 @@ from typing import Optional, List, Type import torch from library import sdxl_original_unet - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) # input_blocksに適用するかどうか / if True, input_blocks are not applied SKIP_INPUT_BLOCKS = False @@ -270,7 +273,7 @@ def apply_to_modules( # create module instances self.lllite_modules = apply_to_modules(self, target_modules) - print(f"enable ControlNet LLLite for U-Net: {len(self.lllite_modules)} modules.") + logger.info(f"enable ControlNet LLLite for U-Net: {len(self.lllite_modules)} modules.") # def prepare_optimizer_params(self): def prepare_params(self): @@ -281,8 +284,8 @@ def prepare_params(self): train_params.append(p) else: non_train_params.append(p) - print(f"count of trainable parameters: {len(train_params)}") - print(f"count of non-trainable parameters: {len(non_train_params)}") + logger.info(f"count of trainable parameters: {len(train_params)}") + logger.info(f"count of non-trainable parameters: {len(non_train_params)}") for p in non_train_params: p.requires_grad_(False) @@ -388,7 +391,7 @@ def load_lllite_weights(self, file, non_lllite_unet_sd=None): matches = pattern.findall(module_name) if matches is not None: for m in matches: - print(module_name, m) + logger.info(f"{module_name} {m}") module_name = module_name.replace(m, m.replace("_", "@")) module_name = module_name.replace("_", ".") module_name = module_name.replace("@", "_") @@ -407,7 +410,7 @@ def forward(self, x, timesteps=None, context=None, y=None, cond_image=None, **kw def replace_unet_linear_and_conv2d(): - print("replace torch.nn.Linear and torch.nn.Conv2d to LLLiteLinear and LLLiteConv2d in U-Net") + logger.info("replace torch.nn.Linear and torch.nn.Conv2d to LLLiteLinear and LLLiteConv2d in U-Net") sdxl_original_unet.torch.nn.Linear = LLLiteLinear sdxl_original_unet.torch.nn.Conv2d = LLLiteConv2d @@ -419,10 +422,10 @@ def replace_unet_linear_and_conv2d(): replace_unet_linear_and_conv2d() # test shape etc - print("create unet") + logger.info("create unet") unet = SdxlUNet2DConditionModelControlNetLLLite() - print("enable ControlNet-LLLite") + logger.info("enable ControlNet-LLLite") unet.apply_lllite(32, 64, None, False, 1.0) unet.to("cuda") # .to(torch.float16) @@ -439,14 +442,14 @@ def replace_unet_linear_and_conv2d(): # unet_sd[converted_key] = model_sd[key] # info = unet.load_lllite_weights("r:/lllite_from_unet.safetensors", unet_sd) - # print(info) + # logger.info(info) - # print(unet) + # logger.info(unet) - # print number of parameters + # logger.info number of parameters params = unet.prepare_params() - print("number of parameters", sum(p.numel() for p in params)) - # print("type any key to continue") + logger.info(f"number of parameters {sum(p.numel() for p in params)}") + # logger.info("type any key to continue") # input() unet.set_use_memory_efficient_attention(True, False) @@ -455,12 +458,12 @@ def replace_unet_linear_and_conv2d(): # # visualize # import torchviz - # print("run visualize") + # logger.info("run visualize") # controlnet.set_control(conditioning_image) # output = unet(x, t, ctx, y) - # print("make_dot") + # logger.info("make_dot") # image = torchviz.make_dot(output, params=dict(controlnet.named_parameters())) - # print("render") + # logger.info("render") # image.format = "svg" # "png" # image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time # input() @@ -471,13 +474,13 @@ def replace_unet_linear_and_conv2d(): scaler = torch.cuda.amp.GradScaler(enabled=True) - print("start training") + logger.info("start training") steps = 10 batch_size = 1 sample_param = [p for p in unet.named_parameters() if ".lllite_up." in p[0]][0] for step in range(steps): - print(f"step {step}") + logger.info(f"step {step}") conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0 x = torch.randn(batch_size, 4, 128, 128).cuda() @@ -494,9 +497,9 @@ def replace_unet_linear_and_conv2d(): scaler.step(optimizer) scaler.update() optimizer.zero_grad(set_to_none=True) - print(sample_param) + logger.info(sample_param) # from safetensors.torch import save_file - # print("save weights") + # logger.info("save weights") # unet.save_lllite_weights("r:/lllite_from_unet.safetensors", torch.float16, None) diff --git a/networks/dylora.py b/networks/dylora.py index e5a55d198..637f33450 100644 --- a/networks/dylora.py +++ b/networks/dylora.py @@ -12,10 +12,15 @@ import math import os import random -from typing import List, Tuple, Union +from typing import Dict, List, Optional, Tuple, Type, Union +from diffusers import AutoencoderKL +from transformers import CLIPTextModel import torch from torch import nn - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) class DyLoRAModule(torch.nn.Module): """ @@ -165,7 +170,15 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs): +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + vae: AutoencoderKL, + text_encoder: Union[CLIPTextModel, List[CLIPTextModel]], + unet, + **kwargs, +): if network_dim is None: network_dim = 4 # default if network_alpha is None: @@ -182,6 +195,7 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un conv_alpha = 1.0 else: conv_alpha = float(conv_alpha) + if unit is not None: unit = int(unit) else: @@ -223,7 +237,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh elif "lora_down" in key: dim = value.size()[0] modules_dim[lora_name] = dim - # print(lora_name, value.size(), dim) + # logger.info(f"{lora_name} {value.size()} {dim}") # support old LoRA without alpha for key in modules_dim.keys(): @@ -267,11 +281,11 @@ def __init__( self.apply_to_conv = apply_to_conv if modules_dim is not None: - print(f"create LoRA network from weights") + logger.info("create LoRA network from weights") else: - print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, unit: {unit}") + logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, unit: {unit}") if self.apply_to_conv: - print(f"apply LoRA to Conv2d with kernel size (3,3).") + logger.info("apply LoRA to Conv2d with kernel size (3,3).") # create module instances def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[DyLoRAModule]: @@ -306,9 +320,23 @@ def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, unit) loras.append(lora) return loras + + text_encoders = text_encoder if type(text_encoder) == list else [text_encoder] + + self.text_encoder_loras = [] + for i, text_encoder in enumerate(text_encoders): + if len(text_encoders) > 1: + index = i + 1 + logger.info(f"create LoRA for Text Encoder {index}") + else: + index = None + logger.info("create LoRA for Text Encoder") + + text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + self.text_encoder_loras.extend(text_encoder_loras) - self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) - print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + # self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights target_modules = DyLoRANetwork.UNET_TARGET_REPLACE_MODULE @@ -316,7 +344,7 @@ def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules target_modules += DyLoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 self.unet_loras = create_modules(True, unet, target_modules) - print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") def set_multiplier(self, multiplier): self.multiplier = multiplier @@ -336,12 +364,12 @@ def load_weights(self, file): def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): if apply_text_encoder: - print("enable LoRA for text encoder") + logger.info("enable LoRA for text encoder") else: self.text_encoder_loras = [] if apply_unet: - print("enable LoRA for U-Net") + logger.info("enable LoRA for U-Net") else: self.unet_loras = [] @@ -359,12 +387,12 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device): apply_unet = True if apply_text_encoder: - print("enable LoRA for text encoder") + logger.info("enable LoRA for text encoder") else: self.text_encoder_loras = [] if apply_unet: - print("enable LoRA for U-Net") + logger.info("enable LoRA for U-Net") else: self.unet_loras = [] @@ -375,7 +403,7 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device): sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] lora.merge_to(sd_for_lora, dtype, device) - print(f"weights are merged") + logger.info(f"weights are merged") """ def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): diff --git a/networks/extract_lora_from_dylora.py b/networks/extract_lora_from_dylora.py index 0abee9836..1184cd8a5 100644 --- a/networks/extract_lora_from_dylora.py +++ b/networks/extract_lora_from_dylora.py @@ -10,7 +10,10 @@ from tqdm import tqdm from library import train_util, model_util import numpy as np - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def load_state_dict(file_name): if model_util.is_safetensors(file_name): @@ -40,13 +43,13 @@ def split_lora_model(lora_sd, unit): rank = value.size()[0] if rank > max_rank: max_rank = rank - print(f"Max rank: {max_rank}") + logger.info(f"Max rank: {max_rank}") rank = unit split_models = [] new_alpha = None while rank < max_rank: - print(f"Splitting rank {rank}") + logger.info(f"Splitting rank {rank}") new_sd = {} for key, value in lora_sd.items(): if "lora_down" in key: @@ -57,7 +60,7 @@ def split_lora_model(lora_sd, unit): # なぜかscaleするとおかしくなる…… # this_rank = lora_sd[key.replace("alpha", "lora_down.weight")].size()[0] # scale = math.sqrt(this_rank / rank) # rank is > unit - # print(key, value.size(), this_rank, rank, value, scale) + # logger.info(key, value.size(), this_rank, rank, value, scale) # new_alpha = value * scale # always same # new_sd[key] = new_alpha new_sd[key] = value @@ -69,10 +72,10 @@ def split_lora_model(lora_sd, unit): def split(args): - print("loading Model...") + logger.info("loading Model...") lora_sd, metadata = load_state_dict(args.model) - print("Splitting Model...") + logger.info("Splitting Model...") original_rank, split_models = split_lora_model(lora_sd, args.unit) comment = metadata.get("ss_training_comment", "") @@ -94,7 +97,7 @@ def split(args): filename, ext = os.path.splitext(args.save_to) model_file_name = filename + f"-{new_rank:04d}{ext}" - print(f"saving model to: {model_file_name}") + logger.info(f"saving model to: {model_file_name}") save_to_file(model_file_name, state_dict, new_metadata) diff --git a/networks/extract_lora_from_models.py b/networks/extract_lora_from_models.py index 6357df55d..43c1d0058 100644 --- a/networks/extract_lora_from_models.py +++ b/networks/extract_lora_from_models.py @@ -11,7 +11,10 @@ from tqdm import tqdm from library import sai_model_spec, model_util, sdxl_model_util import lora - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) # CLAMP_QUANTILE = 0.99 # MIN_DIFF = 1e-1 @@ -43,6 +46,9 @@ def svd( clamp_quantile=0.99, min_diff=0.01, no_metadata=False, + load_precision=None, + load_original_model_to=None, + load_tuned_model_to=None, ): def str_to_dtype(p): if p == "float": @@ -57,28 +63,51 @@ def str_to_dtype(p): if v_parameterization is None: v_parameterization = v2 + load_dtype = str_to_dtype(load_precision) if load_precision else None save_dtype = str_to_dtype(save_precision) + work_device = "cpu" # load models if not sdxl: - print(f"loading original SD model : {model_org}") + logger.info(f"loading original SD model : {model_org}") text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org) text_encoders_o = [text_encoder_o] - print(f"loading tuned SD model : {model_tuned}") + if load_dtype is not None: + text_encoder_o = text_encoder_o.to(load_dtype) + unet_o = unet_o.to(load_dtype) + + logger.info(f"loading tuned SD model : {model_tuned}") text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned) text_encoders_t = [text_encoder_t] + if load_dtype is not None: + text_encoder_t = text_encoder_t.to(load_dtype) + unet_t = unet_t.to(load_dtype) + model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization) else: - print(f"loading original SDXL model : {model_org}") + device_org = load_original_model_to if load_original_model_to else "cpu" + device_tuned = load_tuned_model_to if load_tuned_model_to else "cpu" + + logger.info(f"loading original SDXL model : {model_org}") text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( - sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, "cpu" + sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, device_org ) text_encoders_o = [text_encoder_o1, text_encoder_o2] - print(f"loading original SDXL model : {model_tuned}") + if load_dtype is not None: + text_encoder_o1 = text_encoder_o1.to(load_dtype) + text_encoder_o2 = text_encoder_o2.to(load_dtype) + unet_o = unet_o.to(load_dtype) + + logger.info(f"loading original SDXL model : {model_tuned}") text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( - sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, "cpu" + sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, device_tuned ) text_encoders_t = [text_encoder_t1, text_encoder_t2] + if load_dtype is not None: + text_encoder_t1 = text_encoder_t1.to(load_dtype) + text_encoder_t2 = text_encoder_t2.to(load_dtype) + unet_t = unet_t.to(load_dtype) + model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0 # create LoRA network to extract weights: Use dim (rank) as alpha @@ -100,38 +129,54 @@ def str_to_dtype(p): lora_name = lora_o.lora_name module_o = lora_o.org_module module_t = lora_t.org_module - diff = module_t.weight - module_o.weight + diff = module_t.weight.to(work_device) - module_o.weight.to(work_device) + + # clear weight to save memory + module_o.weight = None + module_t.weight = None # Text Encoder might be same if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff: text_encoder_different = True - print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}") + logger.info(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}") - diff = diff.float() diffs[lora_name] = diff + # clear target Text Encoder to save memory + for text_encoder in text_encoders_t: + del text_encoder + if not text_encoder_different: - print("Text encoder is same. Extract U-Net only.") + logger.warning("Text encoder is same. Extract U-Net only.") lora_network_o.text_encoder_loras = [] - diffs = {} + diffs = {} # clear diffs for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)): lora_name = lora_o.lora_name module_o = lora_o.org_module module_t = lora_t.org_module - diff = module_t.weight - module_o.weight - diff = diff.float() + diff = module_t.weight.to(work_device) - module_o.weight.to(work_device) - if args.device: - diff = diff.to(args.device) + # clear weight to save memory + module_o.weight = None + module_t.weight = None diffs[lora_name] = diff + # clear LoRA network, target U-Net to save memory + del lora_network_o + del lora_network_t + del unet_t + # make LoRA with svd - print("calculating by svd") + logger.info("calculating by svd") lora_weights = {} with torch.no_grad(): for lora_name, mat in tqdm(list(diffs.items())): + if args.device: + mat = mat.to(args.device) + mat = mat.to(torch.float) # calc by float + # if conv_dim is None, diffs do not include LoRAs for conv2d-3x3 conv2d = len(mat.size()) == 4 kernel_size = None if not conv2d else mat.size()[2:4] @@ -143,7 +188,7 @@ def str_to_dtype(p): if device: mat = mat.to(device) - # print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim) + # logger.info(lora_name, mat.size(), mat.device, rank, in_dim, out_dim) rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim if conv2d: @@ -171,8 +216,8 @@ def str_to_dtype(p): U = U.reshape(out_dim, rank, 1, 1) Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1]) - U = U.to("cpu").contiguous() - Vh = Vh.to("cpu").contiguous() + U = U.to(work_device, dtype=save_dtype).contiguous() + Vh = Vh.to(work_device, dtype=save_dtype).contiguous() lora_weights[lora_name] = (U, Vh) @@ -188,7 +233,7 @@ def str_to_dtype(p): lora_network_save.apply_to(text_encoders_o, unet_o) # create internal module references for state_dict info = lora_network_save.load_state_dict(lora_sd) - print(f"Loading extracted LoRA weights: {info}") + logger.info(f"Loading extracted LoRA weights: {info}") dir_name = os.path.dirname(save_to) if dir_name and not os.path.exists(dir_name): @@ -215,7 +260,7 @@ def str_to_dtype(p): metadata.update(sai_metadata) lora_network_save.save_weights(save_to, save_dtype, metadata) - print(f"LoRA weights are saved to: {save_to}") + logger.info(f"LoRA weights are saved to: {save_to}") def setup_parser() -> argparse.ArgumentParser: @@ -230,6 +275,13 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--sdxl", action="store_true", help="load Stable Diffusion SDXL base model / Stable Diffusion SDXL baseのモデルを読み込む" ) + parser.add_argument( + "--load_precision", + type=str, + default=None, + choices=[None, "float", "fp16", "bf16"], + help="precision in loading, model default if omitted / 読み込み時に精度を変更して読み込む、省略時はモデルファイルによる" + ) parser.add_argument( "--save_precision", type=str, @@ -285,6 +337,18 @@ def setup_parser() -> argparse.ArgumentParser: help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / " + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)", ) + parser.add_argument( + "--load_original_model_to", + type=str, + default=None, + help="location to load original model, cpu or cuda, cuda:0, etc, default is cpu, only for SDXL / 元モデル読み込み先、cpuまたはcuda、cuda:0など、省略時はcpu、SDXLのみ有効", + ) + parser.add_argument( + "--load_tuned_model_to", + type=str, + default=None, + help="location to load tuned model, cpu or cuda, cuda:0, etc, default is cpu, only for SDXL / 派生モデル読み込み先、cpuまたはcuda、cuda:0など、省略時はcpu、SDXLのみ有効", + ) return parser diff --git a/networks/lora.py b/networks/lora.py index 0c75cd428..948b30b0e 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -11,7 +11,12 @@ import numpy as np import torch import re +from library.utils import setup_logging +setup_logging() +import logging + +logger = logging.getLogger(__name__) RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") @@ -46,7 +51,7 @@ def __init__( # if limit_rank: # self.lora_dim = min(lora_dim, in_dim, out_dim) # if self.lora_dim != lora_dim: - # print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") + # logger.info(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") # else: self.lora_dim = lora_dim @@ -177,7 +182,7 @@ def merge_to(self, sd, dtype, device): else: # conv2d 3x3 conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - # print(conved.size(), weight.size(), module.stride, module.padding) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) weight = weight + self.multiplier * conved * self.scale # set weight to org_module @@ -216,7 +221,7 @@ def set_region(self, region): self.region_mask = None def default_forward(self, x): - # print("default_forward", self.lora_name, x.size()) + # logger.info(f"default_forward {self.lora_name} {x.size()}") return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale def forward(self, x): @@ -245,7 +250,8 @@ def get_mask_for_x(self, x): if mask is None: # raise ValueError(f"mask is None for resolution {area}") # emb_layers in SDXL doesn't have mask - # print(f"mask is None for resolution {area}, {x.size()}") + # if "emb" not in self.lora_name: + # print(f"mask is None for resolution {self.lora_name}, {area}, {x.size()}") mask_size = (1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1) return torch.ones(mask_size, dtype=x.dtype, device=x.device) / self.network.num_sub_prompts if len(x.size()) != 4: @@ -263,6 +269,8 @@ def regional_forward(self, x): lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale mask = self.get_mask_for_x(lx) # print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size()) + # if mask.ndim > lx.ndim: # in some resolution, lx is 2d and mask is 3d (the reason is not checked) + # mask = mask.squeeze(-1) lx = lx * mask x = self.org_forward(x) @@ -291,7 +299,7 @@ def postp_to_q(self, x): if has_real_uncond: query[-self.network.batch_size :] = x[-self.network.batch_size :] - # print("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts) + # logger.info(f"postp_to_q {self.lora_name} {x.size()} {query.size()} {self.network.num_sub_prompts}") return query def sub_prompt_forward(self, x): @@ -306,7 +314,7 @@ def sub_prompt_forward(self, x): lx = x[emb_idx :: self.network.num_sub_prompts] lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale - # print("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx) + # logger.info(f"sub_prompt_forward {self.lora_name} {x.size()} {lx.size()} {emb_idx}") x = self.org_forward(x) x[emb_idx :: self.network.num_sub_prompts] += lx @@ -314,7 +322,7 @@ def sub_prompt_forward(self, x): return x def to_out_forward(self, x): - # print("to_out_forward", self.lora_name, x.size(), self.network.is_last_network) + # logger.info(f"to_out_forward {self.lora_name} {x.size()} {self.network.is_last_network}") if self.network.is_last_network: masks = [None] * self.network.num_sub_prompts @@ -332,7 +340,7 @@ def to_out_forward(self, x): ) self.network.shared[self.lora_name] = (lx, masks) - # print("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts) + # logger.info(f"to_out_forward {lx.size()} {lx1.size()} {self.network.sub_prompt_index} {self.network.num_sub_prompts}") lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1 masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1) @@ -351,7 +359,7 @@ def to_out_forward(self, x): if has_real_uncond: out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond - # print("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts) + # logger.info(f"to_out_forward {self.lora_name} {self.network.sub_prompt_index} {self.network.num_sub_prompts}") # if num_sub_prompts > num of LoRAs, fill with zero for i in range(len(masks)): if masks[i] is None: @@ -374,7 +382,7 @@ def to_out_forward(self, x): x1 = x1 + lx1 out[self.network.batch_size + i] = x1 - # print("to_out_forward", x.size(), out.size(), has_real_uncond) + # logger.info(f"to_out_forward {x.size()} {out.size()} {has_real_uncond}") return out @@ -511,7 +519,9 @@ def parse_floats(s): len(block_dims) == num_total_blocks ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください" else: - print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります") + logger.warning( + f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります" + ) block_dims = [network_dim] * num_total_blocks if block_alphas is not None: @@ -520,7 +530,7 @@ def parse_floats(s): len(block_alphas) == num_total_blocks ), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください" else: - print( + logger.warning( f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります" ) block_alphas = [network_alpha] * num_total_blocks @@ -540,13 +550,13 @@ def parse_floats(s): else: if conv_alpha is None: conv_alpha = 1.0 - print( + logger.warning( f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります" ) conv_block_alphas = [conv_alpha] * num_total_blocks else: if conv_dim is not None: - print( + logger.warning( f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります" ) conv_block_dims = [conv_dim] * num_total_blocks @@ -586,7 +596,7 @@ def get_list(name_with_suffix) -> List[float]: elif name == "zeros": return [0.0 + base_lr] * max_len else: - print( + logger.error( "Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros" % (name) ) @@ -598,14 +608,14 @@ def get_list(name_with_suffix) -> List[float]: up_lr_weight = get_list(up_lr_weight) if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len): - print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len) - print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len) + logger.warning("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len) + logger.warning("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len) up_lr_weight = up_lr_weight[:max_len] down_lr_weight = down_lr_weight[:max_len] if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len): - print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len) - print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len) + logger.warning("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len) + logger.warning("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len) if down_lr_weight != None and len(down_lr_weight) < max_len: down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight)) @@ -613,24 +623,24 @@ def get_list(name_with_suffix) -> List[float]: up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight)) if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None): - print("apply block learning rate / 階層別学習率を適用します。") + logger.info("apply block learning rate / 階層別学習率を適用します。") if down_lr_weight != None: down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight] - print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", down_lr_weight) + logger.info(f"down_lr_weight (shallower -> deeper, 浅い層->深い層): {down_lr_weight}") else: - print("down_lr_weight: all 1.0, すべて1.0") + logger.info("down_lr_weight: all 1.0, すべて1.0") if mid_lr_weight != None: mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0 - print("mid_lr_weight:", mid_lr_weight) + logger.info(f"mid_lr_weight: {mid_lr_weight}") else: - print("mid_lr_weight: 1.0") + logger.info("mid_lr_weight: 1.0") if up_lr_weight != None: up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight] - print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", up_lr_weight) + logger.info(f"up_lr_weight (deeper -> shallower, 深い層->浅い層): {up_lr_weight}") else: - print("up_lr_weight: all 1.0, すべて1.0") + logger.info("up_lr_weight: all 1.0, すべて1.0") return down_lr_weight, mid_lr_weight, up_lr_weight @@ -711,7 +721,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh elif "lora_down" in key: dim = value.size()[0] modules_dim[lora_name] = dim - # print(lora_name, value.size(), dim) + # logger.info(lora_name, value.size(), dim) # support old LoRA without alpha for key in modules_dim.keys(): @@ -786,20 +796,26 @@ def __init__( self.module_dropout = module_dropout if modules_dim is not None: - print(f"create LoRA network from weights") + logger.info(f"create LoRA network from weights") elif block_dims is not None: - print(f"create LoRA network from block_dims") - print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") - print(f"block_dims: {block_dims}") - print(f"block_alphas: {block_alphas}") + logger.info(f"create LoRA network from block_dims") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) + logger.info(f"block_dims: {block_dims}") + logger.info(f"block_alphas: {block_alphas}") if conv_block_dims is not None: - print(f"conv_block_dims: {conv_block_dims}") - print(f"conv_block_alphas: {conv_block_alphas}") + logger.info(f"conv_block_dims: {conv_block_dims}") + logger.info(f"conv_block_alphas: {conv_block_alphas}") else: - print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") - print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") + logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) if self.conv_lora_dim is not None: - print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") + logger.info( + f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" + ) # create module instances def create_modules( @@ -884,15 +900,15 @@ def create_modules( for i, text_encoder in enumerate(text_encoders): if len(text_encoders) > 1: index = i + 1 - print(f"create LoRA for Text Encoder {index}:") + logger.info(f"create LoRA for Text Encoder {index}:") else: index = None - print(f"create LoRA for Text Encoder:") + logger.info(f"create LoRA for Text Encoder:") text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) self.text_encoder_loras.extend(text_encoder_loras) skipped_te += skipped - print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE @@ -900,15 +916,15 @@ def create_modules( target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules) - print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") skipped = skipped_te + skipped_un if varbose and len(skipped) > 0: - print( + logger.warning( f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" ) for name in skipped: - print(f"\t{name}") + logger.info(f"\t{name}") self.up_lr_weight: List[float] = None self.down_lr_weight: List[float] = None @@ -926,6 +942,10 @@ def set_multiplier(self, multiplier): for lora in self.text_encoder_loras + self.unet_loras: lora.multiplier = self.multiplier + def set_enabled(self, is_enabled): + for lora in self.text_encoder_loras + self.unet_loras: + lora.enabled = is_enabled + def load_weights(self, file): if os.path.splitext(file)[1] == ".safetensors": from safetensors.torch import load_file @@ -939,12 +959,12 @@ def load_weights(self, file): def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): if apply_text_encoder: - print("enable LoRA for text encoder") + logger.info("enable LoRA for text encoder") else: self.text_encoder_loras = [] if apply_unet: - print("enable LoRA for U-Net") + logger.info("enable LoRA for U-Net") else: self.unet_loras = [] @@ -966,12 +986,12 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device): apply_unet = True if apply_text_encoder: - print("enable LoRA for text encoder") + logger.info("enable LoRA for text encoder") else: self.text_encoder_loras = [] if apply_unet: - print("enable LoRA for U-Net") + logger.info("enable LoRA for U-Net") else: self.unet_loras = [] @@ -982,7 +1002,7 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device): sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] lora.merge_to(sd_for_lora, dtype, device) - print(f"weights are merged") + logger.info(f"weights are merged") # 層別学習率用に層ごとの学習率に対する倍率を定義する 引数の順番が逆だがとりあえず気にしない def set_block_lr_weight( @@ -1113,7 +1133,7 @@ def set_region(self, sub_prompt_index, is_last_network, mask): for lora in self.text_encoder_loras + self.unet_loras: lora.set_network(self) - def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared): + def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared, ds_ratio=None): self.batch_size = batch_size self.num_sub_prompts = num_sub_prompts self.current_size = (height, width) @@ -1128,7 +1148,7 @@ def set_current_generation(self, batch_size, num_sub_prompts, width, height, sha device = ref_weight.device def resize_add(mh, mw): - # print(mh, mw, mh * mw) + # logger.info(mh, mw, mh * mw) m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16 m = m.to(device, dtype=dtype) mask_dic[mh * mw] = m @@ -1139,6 +1159,13 @@ def resize_add(mh, mw): resize_add(h, w) if h % 2 == 1 or w % 2 == 1: # add extra shape if h/w is not divisible by 2 resize_add(h + h % 2, w + w % 2) + + # deep shrink + if ds_ratio is not None: + hd = int(h * ds_ratio) + wd = int(w * ds_ratio) + resize_add(hd, wd) + h = (h + 1) // 2 w = (w + 1) // 2 diff --git a/networks/lora_diffusers.py b/networks/lora_diffusers.py index 47d75ac4d..b99b02442 100644 --- a/networks/lora_diffusers.py +++ b/networks/lora_diffusers.py @@ -9,8 +9,15 @@ import numpy as np from tqdm import tqdm from transformers import CLIPTextModel + import torch +from library.device_utils import init_ipex, get_preferred_device +init_ipex() +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def make_unet_conversion_map() -> Dict[str, str]: unet_conversion_map_layer = [] @@ -248,7 +255,7 @@ def create_network_from_weights( elif "lora_down" in key: dim = value.size()[0] modules_dim[lora_name] = dim - # print(lora_name, value.size(), dim) + # logger.info(f"{lora_name} {value.size()} {dim}") # support old LoRA without alpha for key in modules_dim.keys(): @@ -291,12 +298,12 @@ def __init__( super().__init__() self.multiplier = multiplier - print(f"create LoRA network from weights") + logger.info("create LoRA network from weights") # convert SDXL Stability AI's U-Net modules to Diffusers converted = self.convert_unet_modules(modules_dim, modules_alpha) if converted: - print(f"converted {converted} Stability AI's U-Net LoRA modules to Diffusers (SDXL)") + logger.info(f"converted {converted} Stability AI's U-Net LoRA modules to Diffusers (SDXL)") # create module instances def create_modules( @@ -331,7 +338,7 @@ def create_modules( lora_name = lora_name.replace(".", "_") if lora_name not in modules_dim: - # print(f"skipped {lora_name} (not found in modules_dim)") + # logger.info(f"skipped {lora_name} (not found in modules_dim)") skipped.append(lora_name) continue @@ -362,18 +369,18 @@ def create_modules( text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) self.text_encoder_loras.extend(text_encoder_loras) skipped_te += skipped - print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") if len(skipped_te) > 0: - print(f"skipped {len(skipped_te)} modules because of missing weight for text encoder.") + logger.warning(f"skipped {len(skipped_te)} modules because of missing weight for text encoder.") # extend U-Net target modules to include Conv2d 3x3 target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 self.unet_loras: List[LoRAModule] self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules) - print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") if len(skipped_un) > 0: - print(f"skipped {len(skipped_un)} modules because of missing weight for U-Net.") + logger.warning(f"skipped {len(skipped_un)} modules because of missing weight for U-Net.") # assertion names = set() @@ -420,11 +427,11 @@ def set_multiplier(self, multiplier): def apply_to(self, multiplier=1.0, apply_text_encoder=True, apply_unet=True): if apply_text_encoder: - print("enable LoRA for text encoder") + logger.info("enable LoRA for text encoder") for lora in self.text_encoder_loras: lora.apply_to(multiplier) if apply_unet: - print("enable LoRA for U-Net") + logger.info("enable LoRA for U-Net") for lora in self.unet_loras: lora.apply_to(multiplier) @@ -433,16 +440,16 @@ def unapply_to(self): lora.unapply_to() def merge_to(self, multiplier=1.0): - print("merge LoRA weights to original weights") + logger.info("merge LoRA weights to original weights") for lora in tqdm(self.text_encoder_loras + self.unet_loras): lora.merge_to(multiplier) - print(f"weights are merged") + logger.info(f"weights are merged") def restore_from(self, multiplier=1.0): - print("restore LoRA weights from original weights") + logger.info("restore LoRA weights from original weights") for lora in tqdm(self.text_encoder_loras + self.unet_loras): lora.restore_from(multiplier) - print(f"weights are restored") + logger.info(f"weights are restored") def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): # convert SDXL Stability AI's state dict to Diffusers' based state dict @@ -463,7 +470,7 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): my_state_dict = self.state_dict() for key in state_dict.keys(): if state_dict[key].size() != my_state_dict[key].size(): - # print(f"convert {key} from {state_dict[key].size()} to {my_state_dict[key].size()}") + # logger.info(f"convert {key} from {state_dict[key].size()} to {my_state_dict[key].size()}") state_dict[key] = state_dict[key].view(my_state_dict[key].size()) return super().load_state_dict(state_dict, strict) @@ -476,7 +483,7 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline import torch - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + device = get_preferred_device() parser = argparse.ArgumentParser() parser.add_argument("--model_id", type=str, default=None, help="model id for huggingface") @@ -490,7 +497,7 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): image_prefix = args.model_id.replace("/", "_") + "_" # load Diffusers model - print(f"load model from {args.model_id}") + logger.info(f"load model from {args.model_id}") pipe: Union[StableDiffusionPipeline, StableDiffusionXLPipeline] if args.sdxl: # use_safetensors=True does not work with 0.18.2 @@ -503,7 +510,7 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): text_encoders = [pipe.text_encoder, pipe.text_encoder_2] if args.sdxl else [pipe.text_encoder] # load LoRA weights - print(f"load LoRA weights from {args.lora_weights}") + logger.info(f"load LoRA weights from {args.lora_weights}") if os.path.splitext(args.lora_weights)[1] == ".safetensors": from safetensors.torch import load_file @@ -512,10 +519,10 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): lora_sd = torch.load(args.lora_weights) # create by LoRA weights and load weights - print(f"create LoRA network") + logger.info(f"create LoRA network") lora_network: LoRANetwork = create_network_from_weights(text_encoders, pipe.unet, lora_sd, multiplier=1.0) - print(f"load LoRA network weights") + logger.info(f"load LoRA network weights") lora_network.load_state_dict(lora_sd) lora_network.to(device, dtype=pipe.unet.dtype) # required to apply_to. merge_to works without this @@ -544,34 +551,34 @@ def seed_everything(seed): random.seed(seed) # create image with original weights - print(f"create image with original weights") + logger.info(f"create image with original weights") seed_everything(args.seed) image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] image.save(image_prefix + "original.png") # apply LoRA network to the model: slower than merge_to, but can be reverted easily - print(f"apply LoRA network to the model") + logger.info(f"apply LoRA network to the model") lora_network.apply_to(multiplier=1.0) - print(f"create image with applied LoRA") + logger.info(f"create image with applied LoRA") seed_everything(args.seed) image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] image.save(image_prefix + "applied_lora.png") # unapply LoRA network to the model - print(f"unapply LoRA network to the model") + logger.info(f"unapply LoRA network to the model") lora_network.unapply_to() - print(f"create image with unapplied LoRA") + logger.info(f"create image with unapplied LoRA") seed_everything(args.seed) image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] image.save(image_prefix + "unapplied_lora.png") # merge LoRA network to the model: faster than apply_to, but requires back-up of original weights (or unmerge_to) - print(f"merge LoRA network to the model") + logger.info(f"merge LoRA network to the model") lora_network.merge_to(multiplier=1.0) - print(f"create image with LoRA") + logger.info(f"create image with LoRA") seed_everything(args.seed) image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] image.save(image_prefix + "merged_lora.png") @@ -579,31 +586,31 @@ def seed_everything(seed): # restore (unmerge) LoRA weights: numerically unstable # マージされた重みを元に戻す。計算誤差のため、元の重みと完全に一致しないことがあるかもしれない # 保存したstate_dictから元の重みを復元するのが確実 - print(f"restore (unmerge) LoRA weights") + logger.info(f"restore (unmerge) LoRA weights") lora_network.restore_from(multiplier=1.0) - print(f"create image without LoRA") + logger.info(f"create image without LoRA") seed_everything(args.seed) image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] image.save(image_prefix + "unmerged_lora.png") # restore original weights - print(f"restore original weights") + logger.info(f"restore original weights") pipe.unet.load_state_dict(org_unet_sd) pipe.text_encoder.load_state_dict(org_text_encoder_sd) if args.sdxl: pipe.text_encoder_2.load_state_dict(org_text_encoder_2_sd) - print(f"create image with restored original weights") + logger.info(f"create image with restored original weights") seed_everything(args.seed) image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] image.save(image_prefix + "restore_original.png") # use convenience function to merge LoRA weights - print(f"merge LoRA weights with convenience function") + logger.info(f"merge LoRA weights with convenience function") merge_lora_weights(pipe, lora_sd, multiplier=1.0) - print(f"create image with merged LoRA weights") + logger.info(f"create image with merged LoRA weights") seed_everything(args.seed) image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] image.save(image_prefix + "convenience_merged_lora.png") diff --git a/networks/lora_fa.py b/networks/lora_fa.py index a357d7f7f..919222ce8 100644 --- a/networks/lora_fa.py +++ b/networks/lora_fa.py @@ -14,7 +14,10 @@ import numpy as np import torch import re - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") @@ -49,7 +52,7 @@ def __init__( # if limit_rank: # self.lora_dim = min(lora_dim, in_dim, out_dim) # if self.lora_dim != lora_dim: - # print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") + # logger.info(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") # else: self.lora_dim = lora_dim @@ -197,7 +200,7 @@ def merge_to(self, sd, dtype, device): else: # conv2d 3x3 conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - # print(conved.size(), weight.size(), module.stride, module.padding) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) weight = weight + self.multiplier * conved * self.scale # set weight to org_module @@ -236,7 +239,7 @@ def set_region(self, region): self.region_mask = None def default_forward(self, x): - # print("default_forward", self.lora_name, x.size()) + # logger.info("default_forward", self.lora_name, x.size()) return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale def forward(self, x): @@ -278,7 +281,7 @@ def regional_forward(self, x): # apply mask for LoRA result lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale mask = self.get_mask_for_x(lx) - # print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size()) + # logger.info("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size()) lx = lx * mask x = self.org_forward(x) @@ -307,7 +310,7 @@ def postp_to_q(self, x): if has_real_uncond: query[-self.network.batch_size :] = x[-self.network.batch_size :] - # print("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts) + # logger.info("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts) return query def sub_prompt_forward(self, x): @@ -322,7 +325,7 @@ def sub_prompt_forward(self, x): lx = x[emb_idx :: self.network.num_sub_prompts] lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale - # print("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx) + # logger.info("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx) x = self.org_forward(x) x[emb_idx :: self.network.num_sub_prompts] += lx @@ -330,7 +333,7 @@ def sub_prompt_forward(self, x): return x def to_out_forward(self, x): - # print("to_out_forward", self.lora_name, x.size(), self.network.is_last_network) + # logger.info("to_out_forward", self.lora_name, x.size(), self.network.is_last_network) if self.network.is_last_network: masks = [None] * self.network.num_sub_prompts @@ -348,7 +351,7 @@ def to_out_forward(self, x): ) self.network.shared[self.lora_name] = (lx, masks) - # print("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts) + # logger.info("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts) lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1 masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1) @@ -367,7 +370,7 @@ def to_out_forward(self, x): if has_real_uncond: out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond - # print("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts) + # logger.info("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts) # for i in range(len(masks)): # if masks[i] is None: # masks[i] = torch.zeros_like(masks[-1]) @@ -389,7 +392,7 @@ def to_out_forward(self, x): x1 = x1 + lx1 out[self.network.batch_size + i] = x1 - # print("to_out_forward", x.size(), out.size(), has_real_uncond) + # logger.info("to_out_forward", x.size(), out.size(), has_real_uncond) return out @@ -526,7 +529,7 @@ def parse_floats(s): len(block_dims) == num_total_blocks ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください" else: - print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります") + logger.warning(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります") block_dims = [network_dim] * num_total_blocks if block_alphas is not None: @@ -535,7 +538,7 @@ def parse_floats(s): len(block_alphas) == num_total_blocks ), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください" else: - print( + logger.warning( f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります" ) block_alphas = [network_alpha] * num_total_blocks @@ -555,13 +558,13 @@ def parse_floats(s): else: if conv_alpha is None: conv_alpha = 1.0 - print( + logger.warning( f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります" ) conv_block_alphas = [conv_alpha] * num_total_blocks else: if conv_dim is not None: - print( + logger.warning( f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります" ) conv_block_dims = [conv_dim] * num_total_blocks @@ -601,7 +604,7 @@ def get_list(name_with_suffix) -> List[float]: elif name == "zeros": return [0.0 + base_lr] * max_len else: - print( + logger.error( "Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros" % (name) ) @@ -613,14 +616,14 @@ def get_list(name_with_suffix) -> List[float]: up_lr_weight = get_list(up_lr_weight) if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len): - print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len) - print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len) + logger.warning("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len) + logger.warning("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len) up_lr_weight = up_lr_weight[:max_len] down_lr_weight = down_lr_weight[:max_len] if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len): - print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len) - print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len) + logger.warning("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len) + logger.warning("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len) if down_lr_weight != None and len(down_lr_weight) < max_len: down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight)) @@ -628,24 +631,24 @@ def get_list(name_with_suffix) -> List[float]: up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight)) if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None): - print("apply block learning rate / 階層別学習率を適用します。") + logger.info("apply block learning rate / 階層別学習率を適用します。") if down_lr_weight != None: down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight] - print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", down_lr_weight) + logger.info(f"down_lr_weight (shallower -> deeper, 浅い層->深い層): {down_lr_weight}") else: - print("down_lr_weight: all 1.0, すべて1.0") + logger.info("down_lr_weight: all 1.0, すべて1.0") if mid_lr_weight != None: mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0 - print("mid_lr_weight:", mid_lr_weight) + logger.info(f"mid_lr_weight: {mid_lr_weight}") else: - print("mid_lr_weight: 1.0") + logger.info("mid_lr_weight: 1.0") if up_lr_weight != None: up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight] - print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", up_lr_weight) + logger.info(f"up_lr_weight (deeper -> shallower, 深い層->浅い層): {up_lr_weight}") else: - print("up_lr_weight: all 1.0, すべて1.0") + logger.info("up_lr_weight: all 1.0, すべて1.0") return down_lr_weight, mid_lr_weight, up_lr_weight @@ -726,7 +729,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh elif "lora_down" in key: dim = value.size()[0] modules_dim[lora_name] = dim - # print(lora_name, value.size(), dim) + # logger.info(lora_name, value.size(), dim) # support old LoRA without alpha for key in modules_dim.keys(): @@ -801,20 +804,20 @@ def __init__( self.module_dropout = module_dropout if modules_dim is not None: - print(f"create LoRA network from weights") + logger.info(f"create LoRA network from weights") elif block_dims is not None: - print(f"create LoRA network from block_dims") - print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") - print(f"block_dims: {block_dims}") - print(f"block_alphas: {block_alphas}") + logger.info(f"create LoRA network from block_dims") + logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") + logger.info(f"block_dims: {block_dims}") + logger.info(f"block_alphas: {block_alphas}") if conv_block_dims is not None: - print(f"conv_block_dims: {conv_block_dims}") - print(f"conv_block_alphas: {conv_block_alphas}") + logger.info(f"conv_block_dims: {conv_block_dims}") + logger.info(f"conv_block_alphas: {conv_block_alphas}") else: - print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") - print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") + logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") + logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") if self.conv_lora_dim is not None: - print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") + logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") # create module instances def create_modules( @@ -899,15 +902,15 @@ def create_modules( for i, text_encoder in enumerate(text_encoders): if len(text_encoders) > 1: index = i + 1 - print(f"create LoRA for Text Encoder {index}:") + logger.info(f"create LoRA for Text Encoder {index}:") else: index = None - print(f"create LoRA for Text Encoder:") + logger.info(f"create LoRA for Text Encoder:") text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) self.text_encoder_loras.extend(text_encoder_loras) skipped_te += skipped - print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE @@ -915,15 +918,15 @@ def create_modules( target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules) - print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") skipped = skipped_te + skipped_un if varbose and len(skipped) > 0: - print( + logger.warning( f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" ) for name in skipped: - print(f"\t{name}") + logger.info(f"\t{name}") self.up_lr_weight: List[float] = None self.down_lr_weight: List[float] = None @@ -954,12 +957,12 @@ def load_weights(self, file): def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): if apply_text_encoder: - print("enable LoRA for text encoder") + logger.info("enable LoRA for text encoder") else: self.text_encoder_loras = [] if apply_unet: - print("enable LoRA for U-Net") + logger.info("enable LoRA for U-Net") else: self.unet_loras = [] @@ -981,12 +984,12 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device): apply_unet = True if apply_text_encoder: - print("enable LoRA for text encoder") + logger.info("enable LoRA for text encoder") else: self.text_encoder_loras = [] if apply_unet: - print("enable LoRA for U-Net") + logger.info("enable LoRA for U-Net") else: self.unet_loras = [] @@ -997,7 +1000,7 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device): sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] lora.merge_to(sd_for_lora, dtype, device) - print(f"weights are merged") + logger.info(f"weights are merged") # 層別学習率用に層ごとの学習率に対する倍率を定義する 引数の順番が逆だがとりあえず気にしない def set_block_lr_weight( @@ -1144,7 +1147,7 @@ def set_current_generation(self, batch_size, num_sub_prompts, width, height, sha device = ref_weight.device def resize_add(mh, mw): - # print(mh, mw, mh * mw) + # logger.info(mh, mw, mh * mw) m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16 m = m.to(device, dtype=dtype) mask_dic[mh * mw] = m diff --git a/networks/lora_interrogator.py b/networks/lora_interrogator.py index 0dc066fd1..6aaa58107 100644 --- a/networks/lora_interrogator.py +++ b/networks/lora_interrogator.py @@ -5,27 +5,34 @@ import library.train_util as train_util import argparse from transformers import CLIPTokenizer + import torch +from library.device_utils import init_ipex, get_preferred_device +init_ipex() import library.model_util as model_util import lora +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) TOKENIZER_PATH = "openai/clip-vit-large-patch14" V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う -DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +DEVICE = get_preferred_device() def interrogate(args): weights_dtype = torch.float16 # いろいろ準備する - print(f"loading SD model: {args.sd_model}") + logger.info(f"loading SD model: {args.sd_model}") args.pretrained_model_name_or_path = args.sd_model args.vae = None text_encoder, vae, unet, _ = train_util._load_target_model(args,weights_dtype, DEVICE) - print(f"loading LoRA: {args.model}") + logger.info(f"loading LoRA: {args.model}") network, weights_sd = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet) # text encoder向けの重みがあるかチェックする:本当はlora側でやるのがいい @@ -35,11 +42,11 @@ def interrogate(args): has_te_weight = True break if not has_te_weight: - print("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません") + logger.error("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません") return del vae - print("loading tokenizer") + logger.info("loading tokenizer") if args.v2: tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer") else: @@ -53,7 +60,7 @@ def interrogate(args): # トークンをひとつひとつ当たっていく token_id_start = 0 token_id_end = max(tokenizer.all_special_ids) - print(f"interrogate tokens are: {token_id_start} to {token_id_end}") + logger.info(f"interrogate tokens are: {token_id_start} to {token_id_end}") def get_all_embeddings(text_encoder): embs = [] @@ -79,24 +86,24 @@ def get_all_embeddings(text_encoder): embs.extend(encoder_hidden_states) return torch.stack(embs) - print("get original text encoder embeddings.") + logger.info("get original text encoder embeddings.") orig_embs = get_all_embeddings(text_encoder) network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0) info = network.load_state_dict(weights_sd, strict=False) - print(f"Loading LoRA weights: {info}") + logger.info(f"Loading LoRA weights: {info}") network.to(DEVICE, dtype=weights_dtype) network.eval() del unet - print("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません(以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません)") - print("get text encoder embeddings with lora.") + logger.info("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません(以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません)") + logger.info("get text encoder embeddings with lora.") lora_embs = get_all_embeddings(text_encoder) # 比べる:とりあえず単純に差分の絶対値で - print("comparing...") + logger.info("comparing...") diffs = {} for i, (orig_emb, lora_emb) in enumerate(zip(orig_embs, tqdm(lora_embs))): diff = torch.mean(torch.abs(orig_emb - lora_emb)) diff --git a/networks/merge_lora.py b/networks/merge_lora.py index 71492621e..fea8a3f32 100644 --- a/networks/merge_lora.py +++ b/networks/merge_lora.py @@ -7,7 +7,10 @@ from library import sai_model_spec, train_util import library.model_util as model_util import lora - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def load_state_dict(file_name, dtype): if os.path.splitext(file_name)[1] == ".safetensors": @@ -61,10 +64,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): name_to_module[lora_name] = child_module for model, ratio in zip(models, ratios): - print(f"loading: {model}") + logger.info(f"loading: {model}") lora_sd, _ = load_state_dict(model, merge_dtype) - print(f"merging...") + logger.info(f"merging...") for key in lora_sd.keys(): if "lora_down" in key: up_key = key.replace("lora_down", "lora_up") @@ -73,10 +76,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): # find original module for this lora module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight" if module_name not in name_to_module: - print(f"no module found for LoRA weight: {key}") + logger.info(f"no module found for LoRA weight: {key}") continue module = name_to_module[module_name] - # print(f"apply {key} to {module}") + # logger.info(f"apply {key} to {module}") down_weight = lora_sd[key] up_weight = lora_sd[up_key] @@ -104,7 +107,7 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): else: # conv2d 3x3 conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - # print(conved.size(), weight.size(), module.stride, module.padding) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) weight = weight + ratio * conved * scale module.weight = torch.nn.Parameter(weight) @@ -118,7 +121,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): v2 = None base_model = None for model, ratio in zip(models, ratios): - print(f"loading: {model}") + logger.info(f"loading: {model}") lora_sd, lora_metadata = load_state_dict(model, merge_dtype) if lora_metadata is not None: @@ -151,10 +154,10 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): if lora_module_name not in base_alphas: base_alphas[lora_module_name] = alpha - print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") + logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") # merge - print(f"merging...") + logger.info(f"merging...") for key in lora_sd.keys(): if "alpha" in key: continue @@ -196,8 +199,8 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): merged_sd[key_down] = merged_sd[key_down][perm] merged_sd[key_up] = merged_sd[key_up][:,perm] - print("merged model") - print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") + logger.info("merged model") + logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") # check all dims are same dims_list = list(set(base_dims.values())) @@ -239,7 +242,7 @@ def str_to_dtype(p): save_dtype = merge_dtype if args.sd_model is not None: - print(f"loading SD model: {args.sd_model}") + logger.info(f"loading SD model: {args.sd_model}") text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model) @@ -264,18 +267,18 @@ def str_to_dtype(p): ) if args.v2: # TODO read sai modelspec - print( + logger.warning( "Cannot determine if model is for v-prediction, so save metadata as v-prediction / modelがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します" ) - print(f"saving SD model to: {args.save_to}") + logger.info(f"saving SD model to: {args.save_to}") model_util.save_stable_diffusion_checkpoint( args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, sai_metadata, save_dtype, vae ) else: state_dict, metadata, v2 = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) - print(f"calculating hashes and creating metadata...") + logger.info(f"calculating hashes and creating metadata...") model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) metadata["sshs_model_hash"] = model_hash @@ -289,12 +292,12 @@ def str_to_dtype(p): ) if v2: # TODO read sai modelspec - print( + logger.warning( "Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します" ) metadata.update(sai_metadata) - print(f"saving model to: {args.save_to}") + logger.info(f"saving model to: {args.save_to}") save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata) diff --git a/networks/merge_lora_old.py b/networks/merge_lora_old.py index ffd6b2b40..334d127b7 100644 --- a/networks/merge_lora_old.py +++ b/networks/merge_lora_old.py @@ -6,7 +6,10 @@ from safetensors.torch import load_file, save_file import library.model_util as model_util import lora - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def load_state_dict(file_name, dtype): if os.path.splitext(file_name)[1] == '.safetensors': @@ -54,10 +57,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): name_to_module[lora_name] = child_module for model, ratio in zip(models, ratios): - print(f"loading: {model}") + logger.info(f"loading: {model}") lora_sd = load_state_dict(model, merge_dtype) - print(f"merging...") + logger.info(f"merging...") for key in lora_sd.keys(): if "lora_down" in key: up_key = key.replace("lora_down", "lora_up") @@ -66,10 +69,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): # find original module for this lora module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight" if module_name not in name_to_module: - print(f"no module found for LoRA weight: {key}") + logger.info(f"no module found for LoRA weight: {key}") continue module = name_to_module[module_name] - # print(f"apply {key} to {module}") + # logger.info(f"apply {key} to {module}") down_weight = lora_sd[key] up_weight = lora_sd[up_key] @@ -96,10 +99,10 @@ def merge_lora_models(models, ratios, merge_dtype): alpha = None dim = None for model, ratio in zip(models, ratios): - print(f"loading: {model}") + logger.info(f"loading: {model}") lora_sd = load_state_dict(model, merge_dtype) - print(f"merging...") + logger.info(f"merging...") for key in lora_sd.keys(): if 'alpha' in key: if key in merged_sd: @@ -117,7 +120,7 @@ def merge_lora_models(models, ratios, merge_dtype): dim = lora_sd[key].size()[0] merged_sd[key] = lora_sd[key] * ratio - print(f"dim (rank): {dim}, alpha: {alpha}") + logger.info(f"dim (rank): {dim}, alpha: {alpha}") if alpha is None: alpha = dim @@ -142,19 +145,21 @@ def str_to_dtype(p): save_dtype = merge_dtype if args.sd_model is not None: - print(f"loading SD model: {args.sd_model}") + logger.info(f"loading SD model: {args.sd_model}") text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model) merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype) - print(f"\nsaving SD model to: {args.save_to}") + logger.info("") + logger.info(f"saving SD model to: {args.save_to}") model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, save_dtype, vae) else: state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype) - print(f"\nsaving model to: {args.save_to}") + logger.info(f"") + logger.info(f"saving model to: {args.save_to}") save_to_file(args.save_to, state_dict, state_dict, save_dtype) diff --git a/networks/oft.py b/networks/oft.py index 1d088f877..461a98698 100644 --- a/networks/oft.py +++ b/networks/oft.py @@ -8,7 +8,10 @@ import numpy as np import torch import re - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") @@ -237,7 +240,7 @@ def __init__( self.dim = dim self.alpha = alpha - print( + logger.info( f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}" ) @@ -258,7 +261,7 @@ def create_modules( if is_linear or is_conv2d_1x1 or (is_conv2d and enable_conv): oft_name = prefix + "." + name + "." + child_name oft_name = oft_name.replace(".", "_") - # print(oft_name) + # logger.info(oft_name) oft = module_class( oft_name, @@ -279,7 +282,7 @@ def create_modules( target_modules += OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 self.unet_ofts: List[OFTModule] = create_modules(unet, target_modules) - print(f"create OFT for U-Net: {len(self.unet_ofts)} modules.") + logger.info(f"create OFT for U-Net: {len(self.unet_ofts)} modules.") # assertion names = set() @@ -316,7 +319,7 @@ def is_mergeable(self): # TODO refactor to common function with apply_to def merge_to(self, text_encoder, unet, weights_sd, dtype, device): - print("enable OFT for U-Net") + logger.info("enable OFT for U-Net") for oft in self.unet_ofts: sd_for_lora = {} @@ -326,7 +329,7 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device): oft.load_state_dict(sd_for_lora, False) oft.merge_to() - print(f"weights are merged") + logger.info(f"weights are merged") # 二つのText Encoderに別々の学習率を設定できるようにするといいかも def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): @@ -338,11 +341,11 @@ def enumerate_params(ofts): for oft in ofts: params.extend(oft.parameters()) - # print num of params + # logger.info num of params num_params = 0 for p in params: num_params += p.numel() - print(f"OFT params: {num_params}") + logger.info(f"OFT params: {num_params}") return params param_data = {"params": enumerate_params(self.unet_ofts)} diff --git a/networks/resize_lora.py b/networks/resize_lora.py index 03fc545e7..d697baa4c 100644 --- a/networks/resize_lora.py +++ b/networks/resize_lora.py @@ -2,80 +2,91 @@ # This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py # Thanks to cloneofsimo +import os import argparse import torch from safetensors.torch import load_file, save_file, safe_open from tqdm import tqdm -from library import train_util, model_util import numpy as np +from library import train_util +from library import model_util +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + MIN_SV = 1e-6 # Model save and load functions + def load_state_dict(file_name, dtype): - if model_util.is_safetensors(file_name): - sd = load_file(file_name) - with safe_open(file_name, framework="pt") as f: - metadata = f.metadata() - else: - sd = torch.load(file_name, map_location='cpu') - metadata = None + if model_util.is_safetensors(file_name): + sd = load_file(file_name) + with safe_open(file_name, framework="pt") as f: + metadata = f.metadata() + else: + sd = torch.load(file_name, map_location="cpu") + metadata = None - for key in list(sd.keys()): - if type(sd[key]) == torch.Tensor: - sd[key] = sd[key].to(dtype) + for key in list(sd.keys()): + if type(sd[key]) == torch.Tensor: + sd[key] = sd[key].to(dtype) - return sd, metadata + return sd, metadata -def save_to_file(file_name, model, state_dict, dtype, metadata): - if dtype is not None: - for key in list(state_dict.keys()): - if type(state_dict[key]) == torch.Tensor: - state_dict[key] = state_dict[key].to(dtype) +def save_to_file(file_name, state_dict, dtype, metadata): + if dtype is not None: + for key in list(state_dict.keys()): + if type(state_dict[key]) == torch.Tensor: + state_dict[key] = state_dict[key].to(dtype) - if model_util.is_safetensors(file_name): - save_file(model, file_name, metadata) - else: - torch.save(model, file_name) + if model_util.is_safetensors(file_name): + save_file(state_dict, file_name, metadata) + else: + torch.save(state_dict, file_name) # Indexing functions + def index_sv_cumulative(S, target): - original_sum = float(torch.sum(S)) - cumulative_sums = torch.cumsum(S, dim=0)/original_sum - index = int(torch.searchsorted(cumulative_sums, target)) + 1 - index = max(1, min(index, len(S)-1)) + original_sum = float(torch.sum(S)) + cumulative_sums = torch.cumsum(S, dim=0) / original_sum + index = int(torch.searchsorted(cumulative_sums, target)) + 1 + index = max(1, min(index, len(S) - 1)) - return index + return index def index_sv_fro(S, target): - S_squared = S.pow(2) - s_fro_sq = float(torch.sum(S_squared)) - sum_S_squared = torch.cumsum(S_squared, dim=0)/s_fro_sq - index = int(torch.searchsorted(sum_S_squared, target**2)) + 1 - index = max(1, min(index, len(S)-1)) + S_squared = S.pow(2) + S_fro_sq = float(torch.sum(S_squared)) + sum_S_squared = torch.cumsum(S_squared, dim=0) / S_fro_sq + index = int(torch.searchsorted(sum_S_squared, target**2)) + 1 + index = max(1, min(index, len(S) - 1)) - return index + return index def index_sv_ratio(S, target): - max_sv = S[0] - min_sv = max_sv/target - index = int(torch.sum(S > min_sv).item()) - index = max(1, min(index, len(S)-1)) + max_sv = S[0] + min_sv = max_sv / target + index = int(torch.sum(S > min_sv).item()) + index = max(1, min(index, len(S) - 1)) - return index + return index # Modified from Kohaku-blueleaf's extract/merge functions def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1): out_size, in_size, kernel_size, _ = weight.size() U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device)) - + param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale) lora_rank = param_dict["new_rank"] @@ -92,17 +103,17 @@ def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1): out_size, in_size = weight.size() - + U, S, Vh = torch.linalg.svd(weight.to(device)) - + param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale) lora_rank = param_dict["new_rank"] - + U = U[:, :lora_rank] S = S[:lora_rank] U = U @ torch.diag(S) Vh = Vh[:lora_rank, :] - + param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu() param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu() del U, S, Vh, weight @@ -113,7 +124,7 @@ def merge_conv(lora_down, lora_up, device): in_rank, in_size, kernel_size, k_ = lora_down.shape out_size, out_rank, _, _ = lora_up.shape assert in_rank == out_rank and kernel_size == k_, f"rank {in_rank} {out_rank} or kernel {kernel_size} {k_} mismatch" - + lora_down = lora_down.to(device) lora_up = lora_up.to(device) @@ -127,236 +138,274 @@ def merge_linear(lora_down, lora_up, device): in_rank, in_size = lora_down.shape out_size, out_rank = lora_up.shape assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch" - + lora_down = lora_down.to(device) lora_up = lora_up.to(device) - + weight = lora_up @ lora_down del lora_up, lora_down return weight - + # Calculate new rank + def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1): param_dict = {} - if dynamic_method=="sv_ratio": + if dynamic_method == "sv_ratio": # Calculate new dim and alpha based off ratio new_rank = index_sv_ratio(S, dynamic_param) + 1 - new_alpha = float(scale*new_rank) + new_alpha = float(scale * new_rank) - elif dynamic_method=="sv_cumulative": + elif dynamic_method == "sv_cumulative": # Calculate new dim and alpha based off cumulative sum new_rank = index_sv_cumulative(S, dynamic_param) + 1 - new_alpha = float(scale*new_rank) + new_alpha = float(scale * new_rank) - elif dynamic_method=="sv_fro": + elif dynamic_method == "sv_fro": # Calculate new dim and alpha based off sqrt sum of squares new_rank = index_sv_fro(S, dynamic_param) + 1 - new_alpha = float(scale*new_rank) + new_alpha = float(scale * new_rank) else: new_rank = rank - new_alpha = float(scale*new_rank) + new_alpha = float(scale * new_rank) - - if S[0] <= MIN_SV: # Zero matrix, set dim to 1 + if S[0] <= MIN_SV: # Zero matrix, set dim to 1 new_rank = 1 - new_alpha = float(scale*new_rank) - elif new_rank > rank: # cap max rank at rank + new_alpha = float(scale * new_rank) + elif new_rank > rank: # cap max rank at rank new_rank = rank - new_alpha = float(scale*new_rank) - + new_alpha = float(scale * new_rank) # Calculate resize info s_sum = torch.sum(torch.abs(S)) s_rank = torch.sum(torch.abs(S[:new_rank])) - + S_squared = S.pow(2) s_fro = torch.sqrt(torch.sum(S_squared)) s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank])) - fro_percent = float(s_red_fro/s_fro) + fro_percent = float(s_red_fro / s_fro) param_dict["new_rank"] = new_rank param_dict["new_alpha"] = new_alpha - param_dict["sum_retained"] = (s_rank)/s_sum + param_dict["sum_retained"] = (s_rank) / s_sum param_dict["fro_retained"] = fro_percent - param_dict["max_ratio"] = S[0]/S[new_rank - 1] + param_dict["max_ratio"] = S[0] / S[new_rank - 1] return param_dict -def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose): - network_alpha = None - network_dim = None - verbose_str = "\n" - fro_list = [] - - # Extract loaded lora dim and alpha - for key, value in lora_sd.items(): - if network_alpha is None and 'alpha' in key: - network_alpha = value - if network_dim is None and 'lora_down' in key and len(value.size()) == 2: - network_dim = value.size()[0] - if network_alpha is not None and network_dim is not None: - break - if network_alpha is None: - network_alpha = network_dim - - scale = network_alpha/network_dim - - if dynamic_method: - print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}") - - lora_down_weight = None - lora_up_weight = None - - o_lora_sd = lora_sd.copy() - block_down_name = None - block_up_name = None - - with torch.no_grad(): - for key, value in tqdm(lora_sd.items()): - weight_name = None - if 'lora_down' in key: - block_down_name = key.rsplit('.lora_down', 1)[0] - weight_name = key.rsplit(".", 1)[-1] - lora_down_weight = value - else: - continue - - # find corresponding lora_up and alpha - block_up_name = block_down_name - lora_up_weight = lora_sd.get(block_up_name + '.lora_up.' + weight_name, None) - lora_alpha = lora_sd.get(block_down_name + '.alpha', None) - - weights_loaded = (lora_down_weight is not None and lora_up_weight is not None) - - if weights_loaded: - - conv2d = (len(lora_down_weight.size()) == 4) - if lora_alpha is None: - scale = 1.0 - else: - scale = lora_alpha/lora_down_weight.size()[0] - - if conv2d: - full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device) - param_dict = extract_conv(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale) - else: - full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device) - param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale) - - if verbose: - max_ratio = param_dict['max_ratio'] - sum_retained = param_dict['sum_retained'] - fro_retained = param_dict['fro_retained'] - if not np.isnan(fro_retained): - fro_list.append(float(fro_retained)) - - verbose_str+=f"{block_down_name:75} | " - verbose_str+=f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}" - - if verbose and dynamic_method: - verbose_str+=f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n" - else: - verbose_str+=f"\n" - - new_alpha = param_dict['new_alpha'] - o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous() - o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous() - o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict['new_alpha']).to(save_dtype) - - block_down_name = None - block_up_name = None - lora_down_weight = None - lora_up_weight = None - weights_loaded = False - del param_dict - - if verbose: - print(verbose_str) - - print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}") - print("resizing complete") - return o_lora_sd, network_dim, new_alpha +def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose): + network_alpha = None + network_dim = None + verbose_str = "\n" + fro_list = [] + + # Extract loaded lora dim and alpha + for key, value in lora_sd.items(): + if network_alpha is None and "alpha" in key: + network_alpha = value + if network_dim is None and "lora_down" in key and len(value.size()) == 2: + network_dim = value.size()[0] + if network_alpha is not None and network_dim is not None: + break + if network_alpha is None: + network_alpha = network_dim + + scale = network_alpha / network_dim + + if dynamic_method: + logger.info( + f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}" + ) + + lora_down_weight = None + lora_up_weight = None + + o_lora_sd = lora_sd.copy() + block_down_name = None + block_up_name = None + + with torch.no_grad(): + for key, value in tqdm(lora_sd.items()): + weight_name = None + if "lora_down" in key: + block_down_name = key.rsplit(".lora_down", 1)[0] + weight_name = key.rsplit(".", 1)[-1] + lora_down_weight = value + else: + continue + + # find corresponding lora_up and alpha + block_up_name = block_down_name + lora_up_weight = lora_sd.get(block_up_name + ".lora_up." + weight_name, None) + lora_alpha = lora_sd.get(block_down_name + ".alpha", None) + + weights_loaded = lora_down_weight is not None and lora_up_weight is not None + + if weights_loaded: + + conv2d = len(lora_down_weight.size()) == 4 + if lora_alpha is None: + scale = 1.0 + else: + scale = lora_alpha / lora_down_weight.size()[0] + + if conv2d: + full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device) + param_dict = extract_conv(full_weight_matrix, new_conv_rank, dynamic_method, dynamic_param, device, scale) + else: + full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device) + param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale) + + if verbose: + max_ratio = param_dict["max_ratio"] + sum_retained = param_dict["sum_retained"] + fro_retained = param_dict["fro_retained"] + if not np.isnan(fro_retained): + fro_list.append(float(fro_retained)) + + verbose_str += f"{block_down_name:75} | " + verbose_str += ( + f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}" + ) + + if verbose and dynamic_method: + verbose_str += f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n" + else: + verbose_str += "\n" + + new_alpha = param_dict["new_alpha"] + o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous() + o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous() + o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict["new_alpha"]).to(save_dtype) + + block_down_name = None + block_up_name = None + lora_down_weight = None + lora_up_weight = None + weights_loaded = False + del param_dict + + if verbose: + print(verbose_str) + print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}") + logger.info("resizing complete") + return o_lora_sd, network_dim, new_alpha def resize(args): - if args.save_to is None or not (args.save_to.endswith('.ckpt') or args.save_to.endswith('.pt') or args.save_to.endswith('.pth') or args.save_to.endswith('.safetensors')): - raise Exception("The --save_to argument must be specified and must be a .ckpt , .pt, .pth or .safetensors file.") - - - def str_to_dtype(p): - if p == 'float': - return torch.float - if p == 'fp16': - return torch.float16 - if p == 'bf16': - return torch.bfloat16 - return None - - if args.dynamic_method and not args.dynamic_param: - raise Exception("If using dynamic_method, then dynamic_param is required") - - merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32 - save_dtype = str_to_dtype(args.save_precision) - if save_dtype is None: - save_dtype = merge_dtype - - print("loading Model...") - lora_sd, metadata = load_state_dict(args.model, merge_dtype) - - print("Resizing Lora...") - state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose) - - # update metadata - if metadata is None: - metadata = {} - - comment = metadata.get("ss_training_comment", "") - - if not args.dynamic_method: - metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}" - metadata["ss_network_dim"] = str(args.new_rank) - metadata["ss_network_alpha"] = str(new_alpha) - else: - metadata["ss_training_comment"] = f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}" - metadata["ss_network_dim"] = 'Dynamic' - metadata["ss_network_alpha"] = 'Dynamic' + if args.save_to is None or not ( + args.save_to.endswith(".ckpt") + or args.save_to.endswith(".pt") + or args.save_to.endswith(".pth") + or args.save_to.endswith(".safetensors") + ): + raise Exception("The --save_to argument must be specified and must be a .ckpt , .pt, .pth or .safetensors file.") + + args.new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank + + def str_to_dtype(p): + if p == "float": + return torch.float + if p == "fp16": + return torch.float16 + if p == "bf16": + return torch.bfloat16 + return None + + if args.dynamic_method and not args.dynamic_param: + raise Exception("If using dynamic_method, then dynamic_param is required") + + merge_dtype = str_to_dtype("float") # matmul method above only seems to work in float32 + save_dtype = str_to_dtype(args.save_precision) + if save_dtype is None: + save_dtype = merge_dtype + + logger.info("loading Model...") + lora_sd, metadata = load_state_dict(args.model, merge_dtype) + + logger.info("Resizing Lora...") + state_dict, old_dim, new_alpha = resize_lora_model( + lora_sd, args.new_rank, args.new_conv_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose + ) + + # update metadata + if metadata is None: + metadata = {} + + comment = metadata.get("ss_training_comment", "") + + if not args.dynamic_method: + conv_desc = "" if args.new_rank == args.new_conv_rank else f" (conv: {args.new_conv_rank})" + metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}{conv_desc}; {comment}" + metadata["ss_network_dim"] = str(args.new_rank) + metadata["ss_network_alpha"] = str(new_alpha) + else: + metadata["ss_training_comment"] = ( + f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}" + ) + metadata["ss_network_dim"] = "Dynamic" + metadata["ss_network_alpha"] = "Dynamic" - model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) - metadata["sshs_model_hash"] = model_hash - metadata["sshs_legacy_hash"] = legacy_hash + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash - print(f"saving model to: {args.save_to}") - save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata) + logger.info(f"saving model to: {args.save_to}") + save_to_file(args.save_to, state_dict, save_dtype, metadata) def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - - parser.add_argument("--save_precision", type=str, default=None, - choices=[None, "float", "fp16", "bf16"], help="precision in saving, float if omitted / 保存時の精度、未指定時はfloat") - parser.add_argument("--new_rank", type=int, default=4, - help="Specify rank of output LoRA / 出力するLoRAのrank (dim)") - parser.add_argument("--save_to", type=str, default=None, - help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors") - parser.add_argument("--model", type=str, default=None, - help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors") - parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") - parser.add_argument("--verbose", action="store_true", - help="Display verbose resizing information / rank変更時の詳細情報を出力する") - parser.add_argument("--dynamic_method", type=str, default=None, choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"], - help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank") - parser.add_argument("--dynamic_param", type=float, default=None, - help="Specify target for dynamic reduction") - - return parser - - -if __name__ == '__main__': - parser = setup_parser() - - args = parser.parse_args() - resize(args) + parser = argparse.ArgumentParser() + + parser.add_argument( + "--save_precision", + type=str, + default=None, + choices=[None, "float", "fp16", "bf16"], + help="precision in saving, float if omitted / 保存時の精度、未指定時はfloat", + ) + parser.add_argument("--new_rank", type=int, default=4, help="Specify rank of output LoRA / 出力するLoRAのrank (dim)") + parser.add_argument( + "--new_conv_rank", + type=int, + default=None, + help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ", + ) + parser.add_argument( + "--save_to", + type=str, + default=None, + help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors", + ) + parser.add_argument( + "--model", + type=str, + default=None, + help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors", + ) + parser.add_argument( + "--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う" + ) + parser.add_argument( + "--verbose", action="store_true", help="Display verbose resizing information / rank変更時の詳細情報を出力する" + ) + parser.add_argument( + "--dynamic_method", + type=str, + default=None, + choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"], + help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank", + ) + parser.add_argument("--dynamic_param", type=float, default=None, help="Specify target for dynamic reduction") + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + resize(args) diff --git a/networks/sdxl_merge_lora.py b/networks/sdxl_merge_lora.py index c513eb59f..3383a80de 100644 --- a/networks/sdxl_merge_lora.py +++ b/networks/sdxl_merge_lora.py @@ -8,7 +8,10 @@ from library import sai_model_spec, sdxl_model_util, train_util import library.model_util as model_util import lora - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def load_state_dict(file_name, dtype): if os.path.splitext(file_name)[1] == ".safetensors": @@ -66,10 +69,10 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ name_to_module[lora_name] = child_module for model, ratio in zip(models, ratios): - print(f"loading: {model}") + logger.info(f"loading: {model}") lora_sd, _ = load_state_dict(model, merge_dtype) - print(f"merging...") + logger.info(f"merging...") for key in tqdm(lora_sd.keys()): if "lora_down" in key: up_key = key.replace("lora_down", "lora_up") @@ -78,10 +81,10 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ # find original module for this lora module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight" if module_name not in name_to_module: - print(f"no module found for LoRA weight: {key}") + logger.info(f"no module found for LoRA weight: {key}") continue module = name_to_module[module_name] - # print(f"apply {key} to {module}") + # logger.info(f"apply {key} to {module}") down_weight = lora_sd[key] up_weight = lora_sd[up_key] @@ -92,7 +95,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ # W <- W + U * D weight = module.weight - # print(module_name, down_weight.size(), up_weight.size()) + # logger.info(module_name, down_weight.size(), up_weight.size()) if len(weight.size()) == 2: # linear weight = weight + ratio * (up_weight @ down_weight) * scale @@ -107,7 +110,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ else: # conv2d 3x3 conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - # print(conved.size(), weight.size(), module.stride, module.padding) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) weight = weight + ratio * conved * scale module.weight = torch.nn.Parameter(weight) @@ -121,7 +124,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): v2 = None base_model = None for model, ratio in zip(models, ratios): - print(f"loading: {model}") + logger.info(f"loading: {model}") lora_sd, lora_metadata = load_state_dict(model, merge_dtype) if lora_metadata is not None: @@ -154,10 +157,10 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): if lora_module_name not in base_alphas: base_alphas[lora_module_name] = alpha - print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") + logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") # merge - print(f"merging...") + logger.info(f"merging...") for key in tqdm(lora_sd.keys()): if "alpha" in key: continue @@ -200,8 +203,8 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): merged_sd[key_down] = merged_sd[key_down][perm] merged_sd[key_up] = merged_sd[key_up][:,perm] - print("merged model") - print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") + logger.info("merged model") + logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") # check all dims are same dims_list = list(set(base_dims.values())) @@ -243,7 +246,7 @@ def str_to_dtype(p): save_dtype = merge_dtype if args.sd_model is not None: - print(f"loading SD model: {args.sd_model}") + logger.info(f"loading SD model: {args.sd_model}") ( text_model1, @@ -265,14 +268,14 @@ def str_to_dtype(p): None, False, False, True, False, False, time.time(), title=title, merged_from=merged_from ) - print(f"saving SD model to: {args.save_to}") + logger.info(f"saving SD model to: {args.save_to}") sdxl_model_util.save_stable_diffusion_checkpoint( args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, sai_metadata, save_dtype ) else: state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) - print(f"calculating hashes and creating metadata...") + logger.info(f"calculating hashes and creating metadata...") model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) metadata["sshs_model_hash"] = model_hash @@ -286,7 +289,7 @@ def str_to_dtype(p): ) metadata.update(sai_metadata) - print(f"saving model to: {args.save_to}") + logger.info(f"saving model to: {args.save_to}") save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata) diff --git a/networks/svd_merge_lora.py b/networks/svd_merge_lora.py index 16e813b36..cb00a6000 100644 --- a/networks/svd_merge_lora.py +++ b/networks/svd_merge_lora.py @@ -1,4 +1,3 @@ -import math import argparse import os import time @@ -8,7 +7,10 @@ from library import sai_model_spec, train_util import library.model_util as model_util import lora - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) CLAMP_QUANTILE = 0.99 @@ -41,12 +43,12 @@ def save_to_file(file_name, state_dict, dtype, metadata): def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype): - print(f"new rank: {new_rank}, new conv rank: {new_conv_rank}") + logger.info(f"new rank: {new_rank}, new conv rank: {new_conv_rank}") merged_sd = {} v2 = None base_model = None for model, ratio in zip(models, ratios): - print(f"loading: {model}") + logger.info(f"loading: {model}") lora_sd, lora_metadata = load_state_dict(model, merge_dtype) if lora_metadata is not None: @@ -56,7 +58,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None) # merge - print(f"merging...") + logger.info(f"merging...") for key in tqdm(list(lora_sd.keys())): if "lora_down" not in key: continue @@ -73,7 +75,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty out_dim = up_weight.size()[0] conv2d = len(down_weight.size()) == 4 kernel_size = None if not conv2d else down_weight.size()[2:4] - # print(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size) + # logger.info(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size) # make original weight if not exist if lora_module_name not in merged_sd: @@ -110,7 +112,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty merged_sd[lora_module_name] = weight # extract from merged weights - print("extract new lora...") + logger.info("extract new lora...") merged_lora_sd = {} with torch.no_grad(): for lora_module_name, mat in tqdm(list(merged_sd.items())): @@ -188,7 +190,7 @@ def str_to_dtype(p): args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype ) - print(f"calculating hashes and creating metadata...") + logger.info(f"calculating hashes and creating metadata...") model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) metadata["sshs_model_hash"] = model_hash @@ -203,12 +205,12 @@ def str_to_dtype(p): ) if v2: # TODO read sai modelspec - print( + logger.warning( "Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します" ) metadata.update(sai_metadata) - print(f"saving model to: {args.save_to}") + logger.info(f"saving model to: {args.save_to}") save_to_file(args.save_to, state_dict, save_dtype, metadata) diff --git a/requirements.txt b/requirements.txt index 8517d95ac..51085744e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,11 +4,13 @@ diffusers[torch]==0.25.0 ftfy==6.1.1 # albumentations==1.3.0 opencv-python==4.7.0.68 -einops==0.6.1 +einops==0.7.0 pytorch-lightning==1.9.0 -# bitsandbytes==0.39.1 -tensorboard==2.10.1 -safetensors==0.3.1 +bitsandbytes==0.43.0 +prodigyopt==1.0 +lion-pytorch==0.0.6 +tensorboard +safetensors==0.4.2 # gradio==3.16.2 altair==4.2.2 easygui==0.98.3 @@ -22,12 +24,17 @@ huggingface-hub==0.20.1 # for WD14 captioning (tensorflow) # tensorflow==2.10.1 # for WD14 captioning (onnx) -# onnx==1.14.1 -# onnxruntime-gpu==1.16.0 -# onnxruntime==1.16.0 +# onnx==1.15.0 +# onnxruntime-gpu==1.17.1 +# onnxruntime==1.17.1 +# for cuda 12.1(default 11.8) +# onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/ + # this is for onnx: # protobuf==3.20.3 # open clip for SDXL -open-clip-torch==2.20.0 +# open-clip-torch==2.20.0 +# For logging +rich==13.7.0 # for kohya_ss library -e . diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index ab5399842..d52f85a8f 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -16,17 +16,11 @@ import diffusers import numpy as np -import torch - -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +import torch +from library.device_utils import init_ipex, clean_memory, get_preferred_device +init_ipex() - ipex_init() -except Exception: - pass import torchvision from diffusers import ( AutoencoderKL, @@ -60,6 +54,13 @@ from library.sdxl_original_unet import InferSdxlUNet2DConditionModel from library.original_unet import FlashAttentionFunction from networks.control_net_lllite import ControlNetLLLite +from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) # scheduler: SCHEDULER_LINEAR_START = 0.00085 @@ -81,12 +82,12 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa): if mem_eff_attn: - print("Enable memory efficient attention for U-Net") + logger.info("Enable memory efficient attention for U-Net") # これはDiffusersのU-Netではなく自前のU-Netなので置き換えなくても良い unet.set_use_memory_efficient_attention(False, True) elif xformers: - print("Enable xformers for U-Net") + logger.info("Enable xformers for U-Net") try: import xformers.ops except ImportError: @@ -94,7 +95,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio unet.set_use_memory_efficient_attention(True, False) elif sdpa: - print("Enable SDPA for U-Net") + logger.info("Enable SDPA for U-Net") unet.set_use_memory_efficient_attention(False, False) unet.set_use_sdpa(True) @@ -111,7 +112,7 @@ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xform def replace_vae_attn_to_memory_efficient(): - print("VAE Attention.forward has been replaced to FlashAttention (not xformers)") + logger.info("VAE Attention.forward has been replaced to FlashAttention (not xformers)") flash_func = FlashAttentionFunction def forward_flash_attn(self, hidden_states, **kwargs): @@ -167,7 +168,7 @@ def forward_flash_attn_0_14(self, hidden_states, **kwargs): def replace_vae_attn_to_xformers(): - print("VAE: Attention.forward has been replaced to xformers") + logger.info("VAE: Attention.forward has been replaced to xformers") import xformers.ops def forward_xformers(self, hidden_states, **kwargs): @@ -223,7 +224,7 @@ def forward_xformers_0_14(self, hidden_states, **kwargs): def replace_vae_attn_to_sdpa(): - print("VAE: Attention.forward has been replaced to sdpa") + logger.info("VAE: Attention.forward has been replaced to sdpa") def forward_sdpa(self, hidden_states, **kwargs): residual = hidden_states @@ -345,6 +346,8 @@ def __init__( self.control_nets: List[ControlNetLLLite] = [] self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない + self.gradual_latent: GradualLatent = None + # Textual Inversion def add_token_replacement(self, text_encoder_index, target_token_id, rep_token_ids): self.token_replacements_list[text_encoder_index][target_token_id] = rep_token_ids @@ -357,7 +360,7 @@ def get_token_replacer(self, tokenizer): token_replacements = self.token_replacements_list[tokenizer_index] def replace_tokens(tokens): - # print("replace_tokens", tokens, "=>", token_replacements) + # logger.info("replace_tokens", tokens, "=>", token_replacements) if isinstance(tokens, torch.Tensor): tokens = tokens.tolist() @@ -375,6 +378,14 @@ def replace_tokens(tokens): def set_control_nets(self, ctrl_nets): self.control_nets = ctrl_nets + def set_gradual_latent(self, gradual_latent): + if gradual_latent is None: + logger.info("gradual_latent is disabled") + self.gradual_latent = None + else: + logger.info(f"gradual_latent is enabled: {gradual_latent}") + self.gradual_latent = gradual_latent # (ds_ratio, start_timesteps, every_n_steps, ratio_step) + @torch.no_grad() def __call__( self, @@ -449,7 +460,7 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 if not do_classifier_free_guidance and negative_scale is not None: - print(f"negative_scale is ignored if guidance scalle <= 1.0") + logger.info(f"negative_scale is ignored if guidance scalle <= 1.0") negative_scale = None # get unconditional embeddings for classifier free guidance @@ -553,7 +564,7 @@ def __call__( text_pool = text_pool[num_sub_prompts - 1 :: num_sub_prompts] # last subprompt if init_image is not None and self.clip_vision_model is not None: - print(f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}") + logger.info(f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}") vision_input = self.clip_vision_processor(init_image, return_tensors="pt", device=self.device) pixel_values = vision_input["pixel_values"].to(self.device, dtype=text_embeddings.dtype) @@ -645,8 +656,7 @@ def __call__( init_latent_dist = self.vae.encode(init_image.to(self.vae.dtype)).latent_dist init_latents = init_latent_dist.sample(generator=generator) else: - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory() init_latents = [] for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)): init_latent_dist = self.vae.encode( @@ -709,7 +719,116 @@ def __call__( control_net.set_cond_image(None) each_control_net_enabled = [self.control_net_enabled] * len(self.control_nets) + + # # first, we downscale the latents to the half of the size + # # 最初に1/2に縮小する + # height, width = latents.shape[-2:] + # # latents = torch.nn.functional.interpolate(latents.float(), scale_factor=0.5, mode="bicubic", align_corners=False).to( + # # latents.dtype + # # ) + # latents = latents[:, :, ::2, ::2] + # current_scale = 0.5 + + # # how much to increase the scale at each step: .125 seems to work well (because it's 1/8?) + # # 各ステップに拡大率をどのくらい増やすか:.125がよさそう(たぶん1/8なので) + # scale_step = 0.125 + + # # timesteps at which to start increasing the scale: 1000 seems to be enough + # # 拡大を開始するtimesteps: 1000で十分そうである + # start_timesteps = 1000 + + # # how many steps to wait before increasing the scale again + # # small values leads to blurry images (because the latents are blurry after the upscale, so some denoising might be needed) + # # large values leads to flat images + + # # 何ステップごとに拡大するか + # # 小さいとボケる(拡大後のlatentsはボケた感じになるので、そこから数stepのdenoiseが必要と思われる) + # # 大きすぎると細部が書き込まれずのっぺりした感じになる + # every_n_steps = 5 + + # scale_step = input("scale step:") + # scale_step = float(scale_step) + # start_timesteps = input("start timesteps:") + # start_timesteps = int(start_timesteps) + # every_n_steps = input("every n steps:") + # every_n_steps = int(every_n_steps) + + # # for i, t in enumerate(tqdm(timesteps)): + # i = 0 + # last_step = 0 + # while i < len(timesteps): + # t = timesteps[i] + # print(f"[{i}] t={t}") + + # print(i, t, current_scale, latents.shape) + # if t < start_timesteps and current_scale < 1.0 and i % every_n_steps == 0: + # if i == last_step: + # pass + # else: + # print("upscale") + # current_scale = min(current_scale + scale_step, 1.0) + + # h = int(height * current_scale) // 8 * 8 + # w = int(width * current_scale) // 8 * 8 + + # latents = torch.nn.functional.interpolate(latents.float(), size=(h, w), mode="bicubic", align_corners=False).to( + # latents.dtype + # ) + # last_step = i + # i = max(0, i - every_n_steps + 1) + + # diff = timesteps[i] - timesteps[last_step] + # # resized_init_noise = torch.nn.functional.interpolate( + # # init_noise.float(), size=(h, w), mode="bicubic", align_corners=False + # # ).to(latents.dtype) + # # latents = self.scheduler.add_noise(latents, resized_init_noise, diff) + # latents = self.scheduler.add_noise(latents, torch.randn_like(latents), diff * 4) + # # latents += torch.randn_like(latents) / 100 * diff + # continue + + enable_gradual_latent = False + if self.gradual_latent: + if not hasattr(self.scheduler, "set_gradual_latent_params"): + logger.info("gradual_latent is not supported for this scheduler. Ignoring.") + logger.info(f'{self.scheduler.__class__.__name__}') + else: + enable_gradual_latent = True + step_elapsed = 1000 + current_ratio = self.gradual_latent.ratio + + # first, we downscale the latents to the specified ratio / 最初に指定された比率にlatentsをダウンスケールする + height, width = latents.shape[-2:] + org_dtype = latents.dtype + if org_dtype == torch.bfloat16: + latents = latents.float() + latents = torch.nn.functional.interpolate( + latents, scale_factor=current_ratio, mode="bicubic", align_corners=False + ).to(org_dtype) + + # apply unsharp mask / アンシャープマスクを適用する + if self.gradual_latent.gaussian_blur_ksize: + latents = self.gradual_latent.apply_unshark_mask(latents) + for i, t in enumerate(tqdm(timesteps)): + resized_size = None + if enable_gradual_latent: + # gradually upscale the latents / latentsを徐々にアップスケールする + if ( + t < self.gradual_latent.start_timesteps + and current_ratio < 1.0 + and step_elapsed >= self.gradual_latent.every_n_steps + ): + current_ratio = min(current_ratio + self.gradual_latent.ratio_step, 1.0) + # make divisible by 8 because size of latents must be divisible at bottom of UNet + h = int(height * current_ratio) // 8 * 8 + w = int(width * current_ratio) // 8 * 8 + resized_size = (h, w) + self.scheduler.set_gradual_latent_params(resized_size, self.gradual_latent) + step_elapsed = 0 + else: + self.scheduler.set_gradual_latent_params(None, None) + step_elapsed += 1 + # expand the latents if we are doing classifier free guidance latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -720,7 +839,7 @@ def __call__( if not enabled or ratio >= 1.0: continue if ratio < i / len(timesteps): - print(f"ControlNet {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})") + logger.info(f"ControlNet {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})") control_net.set_cond_image(None) each_control_net_enabled[j] = False @@ -778,6 +897,8 @@ def __call__( if is_cancelled_callback is not None and is_cancelled_callback(): return None + i += 1 + if return_latents: return latents @@ -785,8 +906,7 @@ def __call__( if vae_batch_size >= batch_size: image = self.vae.decode(latents.to(self.vae.dtype)).sample else: - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory() images = [] for i in tqdm(range(0, batch_size, vae_batch_size)): images.append( @@ -801,8 +921,7 @@ def __call__( # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory() if output_type == "pil": # image = self.numpy_to_pil(image) @@ -940,7 +1059,7 @@ def get_prompts_with_weights(tokenizer: CLIPTokenizer, token_replacer, prompt: L if word.strip() == "BREAK": # pad until next multiple of tokenizer's max token length pad_len = tokenizer.model_max_length - (len(text_token) % tokenizer.model_max_length) - print(f"BREAK pad_len: {pad_len}") + logger.info(f"BREAK pad_len: {pad_len}") for i in range(pad_len): # v2のときEOSをつけるべきかどうかわからないぜ # if i == 0: @@ -970,7 +1089,7 @@ def get_prompts_with_weights(tokenizer: CLIPTokenizer, token_replacer, prompt: L tokens.append(text_token) weights.append(text_weight) if truncated: - print("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") + logger.warning("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") return tokens, weights @@ -1243,7 +1362,7 @@ def handle_dynamic_prompt_variants(prompt, repeat_count): elif len(count_range) == 2: count_range = [int(count_range[0]), int(count_range[1])] else: - print(f"invalid count range: {count_range}") + logger.warning(f"invalid count range: {count_range}") count_range = [1, 1] if count_range[0] > count_range[1]: count_range = [count_range[1], count_range[0]] @@ -1311,9 +1430,8 @@ def replacer(): # endregion - # def load_clip_l14_336(dtype): -# print(f"loading CLIP: {CLIP_ID_L14_336}") +# logger.info(f"loading CLIP: {CLIP_ID_L14_336}") # text_encoder = CLIPTextModel.from_pretrained(CLIP_ID_L14_336, torch_dtype=dtype) # return text_encoder @@ -1328,6 +1446,7 @@ class BatchDataBase(NamedTuple): mask_image: Any clip_prompt: str guide_image: Any + raw_prompt: str class BatchDataExt(NamedTuple): @@ -1383,7 +1502,7 @@ def main(args): replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa) # tokenizerを読み込む - print("loading tokenizer") + logger.info("loading tokenizer") tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) # schedulerを用意する @@ -1411,7 +1530,7 @@ def main(args): scheduler_module = diffusers.schedulers.scheduling_euler_discrete has_clip_sample = False elif args.sampler == "euler_a" or args.sampler == "k_euler_a": - scheduler_cls = EulerAncestralDiscreteScheduler + scheduler_cls = EulerAncestralDiscreteSchedulerGL scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete has_clip_sample = False elif args.sampler == "dpmsolver" or args.sampler == "dpmsolver++": @@ -1457,7 +1576,7 @@ def reset_sampler_noises(self, noises): self.sampler_noises = noises def randn(self, shape, device=None, dtype=None, layout=None, generator=None): - # print("replacing", shape, len(self.sampler_noises), self.sampler_noise_index) + # logger.info("replacing", shape, len(self.sampler_noises), self.sampler_noise_index) if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises): noise = self.sampler_noises[self.sampler_noise_index] if shape != noise.shape: @@ -1466,7 +1585,7 @@ def randn(self, shape, device=None, dtype=None, layout=None, generator=None): noise = None if noise == None: - print(f"unexpected noise request: {self.sampler_noise_index}, {shape}") + logger.warning(f"unexpected noise request: {self.sampler_noise_index}, {shape}") noise = torch.randn(shape, dtype=dtype, device=device, generator=generator) self.sampler_noise_index += 1 @@ -1498,11 +1617,11 @@ def __getattr__(self, item): # ↓以下は結局PipeでFalseに設定されるので意味がなかった # # clip_sample=Trueにする # if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: - # print("set clip_sample to True") + # logger.info("set clip_sample to True") # scheduler.config.clip_sample = True # deviceを決定する - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # "mps"を考量してない + device = get_preferred_device() # custom pipelineをコピったやつを生成する if args.vae_slices: @@ -1527,7 +1646,7 @@ def __getattr__(self, item): vae_dtype = dtype if args.no_half_vae: - print("set vae_dtype to float32") + logger.info("set vae_dtype to float32") vae_dtype = torch.float32 vae.to(vae_dtype).to(device) vae.eval() @@ -1552,10 +1671,10 @@ def __getattr__(self, item): network_merge = args.network_merge_n_models else: network_merge = 0 - print(f"network_merge: {network_merge}") + logger.info(f"network_merge: {network_merge}") for i, network_module in enumerate(args.network_module): - print("import network module:", network_module) + logger.info(f"import network module: {network_module}") imported_module = importlib.import_module(network_module) network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] @@ -1573,7 +1692,7 @@ def __getattr__(self, item): raise ValueError("No weight. Weight is required.") network_weight = args.network_weights[i] - print("load network weights from:", network_weight) + logger.info(f"load network weights from: {network_weight}") if model_util.is_safetensors(network_weight) and args.network_show_meta: from safetensors.torch import safe_open @@ -1581,7 +1700,7 @@ def __getattr__(self, item): with safe_open(network_weight, framework="pt") as f: metadata = f.metadata() if metadata is not None: - print(f"metadata for: {network_weight}: {metadata}") + logger.info(f"metadata for: {network_weight}: {metadata}") network, weights_sd = imported_module.create_network_from_weights( network_mul, network_weight, vae, [text_encoder1, text_encoder2], unet, for_inference=True, **net_kwargs @@ -1591,20 +1710,20 @@ def __getattr__(self, item): mergeable = network.is_mergeable() if network_merge and not mergeable: - print("network is not mergiable. ignore merge option.") + logger.warning("network is not mergiable. ignore merge option.") if not mergeable or i >= network_merge: # not merging network.apply_to([text_encoder1, text_encoder2], unet) info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい - print(f"weights are loaded: {info}") + logger.info(f"weights are loaded: {info}") if args.opt_channels_last: network.to(memory_format=torch.channels_last) network.to(dtype).to(device) if network_pre_calc: - print("backup original weights") + logger.info("backup original weights") network.backup_weights() networks.append(network) @@ -1618,7 +1737,7 @@ def __getattr__(self, item): # upscalerの指定があれば取得する upscaler = None if args.highres_fix_upscaler: - print("import upscaler module:", args.highres_fix_upscaler) + logger.info(f"import upscaler module: {args.highres_fix_upscaler}") imported_module = importlib.import_module(args.highres_fix_upscaler) us_kwargs = {} @@ -1627,7 +1746,7 @@ def __getattr__(self, item): key, value = net_arg.split("=") us_kwargs[key] = value - print("create upscaler") + logger.info("create upscaler") upscaler = imported_module.create_upscaler(**us_kwargs) upscaler.to(dtype).to(device) @@ -1644,7 +1763,7 @@ def __getattr__(self, item): # control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) if args.control_net_lllite_models: for i, model_file in enumerate(args.control_net_lllite_models): - print(f"loading ControlNet-LLLite: {model_file}") + logger.info(f"loading ControlNet-LLLite: {model_file}") from safetensors.torch import load_file @@ -1675,7 +1794,7 @@ def __getattr__(self, item): control_nets.append((control_net, ratio)) if args.opt_channels_last: - print(f"set optimizing: channels last") + logger.info(f"set optimizing: channels last") text_encoder1.to(memory_format=torch.channels_last) text_encoder2.to(memory_format=torch.channels_last) vae.to(memory_format=torch.channels_last) @@ -1699,7 +1818,7 @@ def __getattr__(self, item): args.clip_skip, ) pipe.set_control_nets(control_nets) - print("pipeline is ready.") + logger.info("pipeline is ready.") if args.diffusers_xformers: pipe.enable_xformers_memory_efficient_attention() @@ -1708,6 +1827,29 @@ def __getattr__(self, item): if args.ds_depth_1 is not None: unet.set_deep_shrink(args.ds_depth_1, args.ds_timesteps_1, args.ds_depth_2, args.ds_timesteps_2, args.ds_ratio) + # Gradual Latent + if args.gradual_latent_timesteps is not None: + if args.gradual_latent_unsharp_params: + us_params = args.gradual_latent_unsharp_params.split(",") + us_ksize, us_sigma, us_strength = [float(v) for v in us_params[:3]] + us_target_x = True if len(us_params) <= 3 else bool(int(us_params[3])) + us_ksize = int(us_ksize) + else: + us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None + + gradual_latent = GradualLatent( + args.gradual_latent_ratio, + args.gradual_latent_timesteps, + args.gradual_latent_every_n_steps, + args.gradual_latent_ratio_step, + args.gradual_latent_s_noise, + us_ksize, + us_sigma, + us_strength, + us_target_x, + ) + pipe.set_gradual_latent(gradual_latent) + # Textual Inversionを処理する if args.textual_inversion_embeddings: token_ids_embeds1 = [] @@ -1741,7 +1883,7 @@ def __getattr__(self, item): token_ids1 = tokenizer1.convert_tokens_to_ids(token_strings) token_ids2 = tokenizer2.convert_tokens_to_ids(token_strings) - print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids1} and {token_ids2}") + logger.info(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids1} and {token_ids2}") assert ( min(token_ids1) == token_ids1[0] and token_ids1[-1] == token_ids1[0] + len(token_ids1) - 1 ), f"token ids1 is not ordered" @@ -1771,10 +1913,10 @@ def __getattr__(self, item): # promptを取得する if args.from_file is not None: - print(f"reading prompts from {args.from_file}") + logger.info(f"reading prompts from {args.from_file}") with open(args.from_file, "r", encoding="utf-8") as f: prompt_list = f.read().splitlines() - prompt_list = [d for d in prompt_list if len(d.strip()) > 0] + prompt_list = [d for d in prompt_list if len(d.strip()) > 0 and d[0] != "#"] elif args.prompt is not None: prompt_list = [args.prompt] else: @@ -1800,7 +1942,7 @@ def load_images(path): for p in paths: image = Image.open(p) if image.mode != "RGB": - print(f"convert image to RGB from {image.mode}: {p}") + logger.info(f"convert image to RGB from {image.mode}: {p}") image = image.convert("RGB") images.append(image) @@ -1816,14 +1958,14 @@ def resize_images(imgs, size): return resized if args.image_path is not None: - print(f"load image for img2img: {args.image_path}") + logger.info(f"load image for img2img: {args.image_path}") init_images = load_images(args.image_path) assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}" - print(f"loaded {len(init_images)} images for img2img") + logger.info(f"loaded {len(init_images)} images for img2img") # CLIP Vision if args.clip_vision_strength is not None: - print(f"load CLIP Vision model: {CLIP_VISION_MODEL}") + logger.info(f"load CLIP Vision model: {CLIP_VISION_MODEL}") vision_model = CLIPVisionModelWithProjection.from_pretrained(CLIP_VISION_MODEL, projection_dim=1280) vision_model.to(device, dtype) processor = CLIPImageProcessor.from_pretrained(CLIP_VISION_MODEL) @@ -1831,22 +1973,22 @@ def resize_images(imgs, size): pipe.clip_vision_model = vision_model pipe.clip_vision_processor = processor pipe.clip_vision_strength = args.clip_vision_strength - print(f"CLIP Vision model loaded.") + logger.info(f"CLIP Vision model loaded.") else: init_images = None if args.mask_path is not None: - print(f"load mask for inpainting: {args.mask_path}") + logger.info(f"load mask for inpainting: {args.mask_path}") mask_images = load_images(args.mask_path) assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.image_path}" - print(f"loaded {len(mask_images)} mask images for inpainting") + logger.info(f"loaded {len(mask_images)} mask images for inpainting") else: mask_images = None # promptがないとき、画像のPngInfoから取得する if init_images is not None and len(prompt_list) == 0 and not args.interactive: - print("get prompts from images' metadata") + logger.info("get prompts from images' metadata") for img in init_images: if "prompt" in img.text: prompt = img.text["prompt"] @@ -1875,17 +2017,17 @@ def resize_images(imgs, size): h = int(h * args.highres_fix_scale + 0.5) if init_images is not None: - print(f"resize img2img source images to {w}*{h}") + logger.info(f"resize img2img source images to {w}*{h}") init_images = resize_images(init_images, (w, h)) if mask_images is not None: - print(f"resize img2img mask images to {w}*{h}") + logger.info(f"resize img2img mask images to {w}*{h}") mask_images = resize_images(mask_images, (w, h)) regional_network = False if networks and mask_images: # mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応 regional_network = True - print("use mask as region") + logger.info("use mask as region") size = None for i, network in enumerate(networks): @@ -1910,14 +2052,16 @@ def resize_images(imgs, size): prev_image = None # for VGG16 guided if args.guide_image_path is not None: - print(f"load image for ControlNet guidance: {args.guide_image_path}") + logger.info(f"load image for ControlNet guidance: {args.guide_image_path}") guide_images = [] for p in args.guide_image_path: guide_images.extend(load_images(p)) - print(f"loaded {len(guide_images)} guide images for guidance") + logger.info(f"loaded {len(guide_images)} guide images for guidance") if len(guide_images) == 0: - print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}") + logger.warning( + f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}" + ) guide_images = None else: guide_images = None @@ -1943,7 +2087,7 @@ def resize_images(imgs, size): max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples for gen_iter in range(args.n_iter): - print(f"iteration {gen_iter+1}/{args.n_iter}") + logger.info(f"iteration {gen_iter+1}/{args.n_iter}") iter_seed = random.randint(0, 0x7FFFFFFF) # バッチ処理の関数 @@ -1955,7 +2099,7 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling - print("process 1st stage") + logger.info("process 1st stage") batch_1st = [] for _, base, ext in batch: @@ -2000,7 +2144,7 @@ def scale_and_round(x): images_1st = process_batch(batch_1st, True, True) # 2nd stageのバッチを作成して以下処理する - print("process 2nd stage") + logger.info("process 2nd stage") width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height if upscaler: @@ -2046,7 +2190,7 @@ def scale_and_round(x): # このバッチの情報を取り出す ( return_latents, - (step_first, _, _, _, init_image, mask_image, _, guide_image), + (step_first, _, _, _, init_image, mask_image, _, guide_image, _), ( width, height, @@ -2068,6 +2212,7 @@ def scale_and_round(x): prompts = [] negative_prompts = [] + raw_prompts = [] start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) noises = [ torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) @@ -2098,11 +2243,16 @@ def scale_and_round(x): all_images_are_same = True all_masks_are_same = True all_guide_images_are_same = True - for i, (_, (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch): + for i, ( + _, + (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt), + _, + ) in enumerate(batch): prompts.append(prompt) negative_prompts.append(negative_prompt) seeds.append(seed) clip_prompts.append(clip_prompt) + raw_prompts.append(raw_prompt) if init_image is not None: init_images.append(init_image) @@ -2166,7 +2316,7 @@ def scale_and_round(x): n.restore_weights() for n in networks: n.pre_calculation() - print("pre-calculation... done") + logger.info("pre-calculation... done") images = pipe( prompts, @@ -2200,8 +2350,8 @@ def scale_and_round(x): # save image highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) - for i, (image, prompt, negative_prompts, seed, clip_prompt) in enumerate( - zip(images, prompts, negative_prompts, seeds, clip_prompts) + for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate( + zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts) ): if highres_fix: seed -= 1 # record original seed @@ -2217,6 +2367,8 @@ def scale_and_round(x): metadata.add_text("negative-scale", str(negative_scale)) if clip_prompt is not None: metadata.add_text("clip-prompt", clip_prompt) + if raw_prompt is not None: + metadata.add_text("raw-prompt", raw_prompt) metadata.add_text("original-height", str(original_height)) metadata.add_text("original-width", str(original_width)) metadata.add_text("original-height-negative", str(original_height_negative)) @@ -2245,7 +2397,9 @@ def scale_and_round(x): cv2.waitKey() cv2.destroyAllWindows() except ImportError: - print("opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません") + logger.error( + "opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません" + ) return images @@ -2258,7 +2412,8 @@ def scale_and_round(x): # interactive valid = False while not valid: - print("\nType prompt:") + logger.info("") + logger.info("Type prompt:") try: raw_prompt = input() except EOFError: @@ -2305,76 +2460,84 @@ def scale_and_round(x): ds_timesteps_2 = args.ds_timesteps_2 ds_ratio = args.ds_ratio + # Gradual Latent + gl_timesteps = None # means no override + gl_ratio = args.gradual_latent_ratio + gl_every_n_steps = args.gradual_latent_every_n_steps + gl_ratio_step = args.gradual_latent_ratio_step + gl_s_noise = args.gradual_latent_s_noise + gl_unsharp_params = args.gradual_latent_unsharp_params + prompt_args = raw_prompt.strip().split(" --") prompt = prompt_args[0] - print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") + logger.info(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") for parg in prompt_args[1:]: try: m = re.match(r"w (\d+)", parg, re.IGNORECASE) if m: width = int(m.group(1)) - print(f"width: {width}") + logger.info(f"width: {width}") continue m = re.match(r"h (\d+)", parg, re.IGNORECASE) if m: height = int(m.group(1)) - print(f"height: {height}") + logger.info(f"height: {height}") continue m = re.match(r"ow (\d+)", parg, re.IGNORECASE) if m: original_width = int(m.group(1)) - print(f"original width: {original_width}") + logger.info(f"original width: {original_width}") continue m = re.match(r"oh (\d+)", parg, re.IGNORECASE) if m: original_height = int(m.group(1)) - print(f"original height: {original_height}") + logger.info(f"original height: {original_height}") continue m = re.match(r"nw (\d+)", parg, re.IGNORECASE) if m: original_width_negative = int(m.group(1)) - print(f"original width negative: {original_width_negative}") + logger.info(f"original width negative: {original_width_negative}") continue m = re.match(r"nh (\d+)", parg, re.IGNORECASE) if m: original_height_negative = int(m.group(1)) - print(f"original height negative: {original_height_negative}") + logger.info(f"original height negative: {original_height_negative}") continue m = re.match(r"ct (\d+)", parg, re.IGNORECASE) if m: crop_top = int(m.group(1)) - print(f"crop top: {crop_top}") + logger.info(f"crop top: {crop_top}") continue m = re.match(r"cl (\d+)", parg, re.IGNORECASE) if m: crop_left = int(m.group(1)) - print(f"crop left: {crop_left}") + logger.info(f"crop left: {crop_left}") continue m = re.match(r"s (\d+)", parg, re.IGNORECASE) if m: # steps steps = max(1, min(1000, int(m.group(1)))) - print(f"steps: {steps}") + logger.info(f"steps: {steps}") continue m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) if m: # seed seeds = [int(d) for d in m.group(1).split(",")] - print(f"seeds: {seeds}") + logger.info(f"seeds: {seeds}") continue m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) if m: # scale scale = float(m.group(1)) - print(f"scale: {scale}") + logger.info(f"scale: {scale}") continue m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) @@ -2383,25 +2546,25 @@ def scale_and_round(x): negative_scale = None else: negative_scale = float(m.group(1)) - print(f"negative scale: {negative_scale}") + logger.info(f"negative scale: {negative_scale}") continue m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) if m: # strength strength = float(m.group(1)) - print(f"strength: {strength}") + logger.info(f"strength: {strength}") continue m = re.match(r"n (.+)", parg, re.IGNORECASE) if m: # negative prompt negative_prompt = m.group(1) - print(f"negative prompt: {negative_prompt}") + logger.info(f"negative prompt: {negative_prompt}") continue m = re.match(r"c (.+)", parg, re.IGNORECASE) if m: # clip prompt clip_prompt = m.group(1) - print(f"clip prompt: {clip_prompt}") + logger.info(f"clip prompt: {clip_prompt}") continue m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) @@ -2409,47 +2572,131 @@ def scale_and_round(x): network_muls = [float(v) for v in m.group(1).split(",")] while len(network_muls) < len(networks): network_muls.append(network_muls[-1]) - print(f"network mul: {network_muls}") + logger.info(f"network mul: {network_muls}") continue # Deep Shrink m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink depth 1 ds_depth_1 = int(m.group(1)) - print(f"deep shrink depth 1: {ds_depth_1}") + logger.info(f"deep shrink depth 1: {ds_depth_1}") continue m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink timesteps 1 ds_timesteps_1 = int(m.group(1)) ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - print(f"deep shrink timesteps 1: {ds_timesteps_1}") + logger.info(f"deep shrink timesteps 1: {ds_timesteps_1}") continue m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink depth 2 ds_depth_2 = int(m.group(1)) ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - print(f"deep shrink depth 2: {ds_depth_2}") + logger.info(f"deep shrink depth 2: {ds_depth_2}") continue m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink timesteps 2 ds_timesteps_2 = int(m.group(1)) ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - print(f"deep shrink timesteps 2: {ds_timesteps_2}") + logger.info(f"deep shrink timesteps 2: {ds_timesteps_2}") continue m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE) if m: # deep shrink ratio ds_ratio = float(m.group(1)) ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - print(f"deep shrink ratio: {ds_ratio}") + logger.info(f"deep shrink ratio: {ds_ratio}") + continue + + # Gradual Latent + m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent timesteps + gl_timesteps = int(m.group(1)) + logger.info(f"gradual latent timesteps: {gl_timesteps}") + continue + + m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio + gl_ratio = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent ratio: {ds_ratio}") + continue + + m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent every n steps + gl_every_n_steps = int(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent every n steps: {gl_every_n_steps}") + continue + + m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio step + gl_ratio_step = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent ratio step: {gl_ratio_step}") + continue + + m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent s noise + gl_s_noise = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent s noise: {gl_s_noise}") + continue + + m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # gradual latent unsharp params + gl_unsharp_params = m.group(1) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") + continue + + # Gradual Latent + m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent timesteps + gl_timesteps = int(m.group(1)) + logger.info(f"gradual latent timesteps: {gl_timesteps}") + continue + + m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio + gl_ratio = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent ratio: {ds_ratio}") + continue + + m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent every n steps + gl_every_n_steps = int(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent every n steps: {gl_every_n_steps}") + continue + + m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio step + gl_ratio_step = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent ratio step: {gl_ratio_step}") + continue + + m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent s noise + gl_s_noise = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent s noise: {gl_s_noise}") + continue + + m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # gradual latent unsharp params + gl_unsharp_params = m.group(1) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") continue except ValueError as ex: - print(f"Exception in parsing / 解析エラー: {parg}") - print(ex) + logger.error(f"Exception in parsing / 解析エラー: {parg}") + logger.error(f"{ex}") # override Deep Shrink if ds_depth_1 is not None: @@ -2457,6 +2704,30 @@ def scale_and_round(x): ds_depth_1 = args.ds_depth_1 or 3 unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio) + # override Gradual Latent + if gl_timesteps is not None: + if gl_timesteps < 0: + gl_timesteps = args.gradual_latent_timesteps or 650 + if gl_unsharp_params is not None: + unsharp_params = gl_unsharp_params.split(",") + us_ksize, us_sigma, us_strength = [float(v) for v in unsharp_params[:3]] + us_target_x = True if len(unsharp_params) < 4 else bool(int(unsharp_params[3])) + us_ksize = int(us_ksize) + else: + us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None + gradual_latent = GradualLatent( + gl_ratio, + gl_timesteps, + gl_every_n_steps, + gl_ratio_step, + gl_s_noise, + us_ksize, + us_sigma, + us_strength, + us_target_x, + ) + pipe.set_gradual_latent(gradual_latent) + # prepare seed if seeds is not None: # given in prompt # 数が足りないなら前のをそのまま使う @@ -2467,7 +2738,7 @@ def scale_and_round(x): if len(predefined_seeds) > 0: seed = predefined_seeds.pop(0) else: - print("predefined seeds are exhausted") + logger.error("predefined seeds are exhausted") seed = None elif args.iter_same_seed: seeds = iter_seed @@ -2477,7 +2748,7 @@ def scale_and_round(x): if seed is None: seed = random.randint(0, 0x7FFFFFFF) if args.interactive: - print(f"seed: {seed}") + logger.info(f"seed: {seed}") # prepare init image, guide image and mask init_image = mask_image = guide_image = None @@ -2493,7 +2764,7 @@ def scale_and_round(x): width = width - width % 32 height = height - height % 32 if width != init_image.size[0] or height != init_image.size[1]: - print( + logger.warning( f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます" ) @@ -2518,7 +2789,9 @@ def scale_and_round(x): b1 = BatchData( False, - BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), + BatchDataBase( + global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt + ), BatchDataExt( width, height, @@ -2553,18 +2826,25 @@ def scale_and_round(x): process_batch(batch_data, highres_fix) batch_data.clear() - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) + parser.add_argument("--prompt", type=str, default=None, help="prompt / プロンプト") parser.add_argument( - "--from_file", type=str, default=None, help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む" + "--from_file", + type=str, + default=None, + help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む", ) parser.add_argument( - "--interactive", action="store_true", help="interactive mode (generates one image) / 対話モード(生成される画像は1枚になります)" + "--interactive", + action="store_true", + help="interactive mode (generates one image) / 対話モード(生成される画像は1枚になります)", ) parser.add_argument( "--no_preview", action="store_true", help="do not show generated image in interactive mode / 対話モードで画像を表示しない" @@ -2576,7 +2856,9 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("--strength", type=float, default=None, help="img2img strength / img2img時のstrength") parser.add_argument("--images_per_prompt", type=int, default=1, help="number of images per prompt / プロンプトあたりの出力枚数") parser.add_argument("--outdir", type=str, default="outputs", help="dir to write results to / 生成画像の出力先") - parser.add_argument("--sequential_file_name", action="store_true", help="sequential output file name / 生成画像のファイル名を連番にする") + parser.add_argument( + "--sequential_file_name", action="store_true", help="sequential output file name / 生成画像のファイル名を連番にする" + ) parser.add_argument( "--use_original_file_name", action="store_true", @@ -2587,10 +2869,16 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ") parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅") parser.add_argument( - "--original_height", type=int, default=None, help="original height for SDXL conditioning / SDXLの条件付けに用いるoriginal heightの値" + "--original_height", + type=int, + default=None, + help="original height for SDXL conditioning / SDXLの条件付けに用いるoriginal heightの値", ) parser.add_argument( - "--original_width", type=int, default=None, help="original width for SDXL conditioning / SDXLの条件付けに用いるoriginal widthの値" + "--original_width", + type=int, + default=None, + help="original width for SDXL conditioning / SDXLの条件付けに用いるoriginal widthの値", ) parser.add_argument( "--original_height_negative", @@ -2604,8 +2892,12 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="original width for SDXL unconditioning / SDXLのネガティブ条件付けに用いるoriginal widthの値", ) - parser.add_argument("--crop_top", type=int, default=None, help="crop top for SDXL conditioning / SDXLの条件付けに用いるcrop topの値") - parser.add_argument("--crop_left", type=int, default=None, help="crop left for SDXL conditioning / SDXLの条件付けに用いるcrop leftの値") + parser.add_argument( + "--crop_top", type=int, default=None, help="crop top for SDXL conditioning / SDXLの条件付けに用いるcrop topの値" + ) + parser.add_argument( + "--crop_left", type=int, default=None, help="crop left for SDXL conditioning / SDXLの条件付けに用いるcrop leftの値" + ) parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ") parser.add_argument( "--vae_batch_size", @@ -2619,7 +2911,9 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="number of slices to split image into for VAE to reduce VRAM usage, None for no splitting (default), slower if specified. 16 or 32 recommended / VAE処理時にVRAM使用量削減のため画像を分割するスライス数、Noneの場合は分割しない(デフォルト)、指定すると遅くなる。16か32程度を推奨", ) - parser.add_argument("--no_half_vae", action="store_true", help="do not use fp16/bf16 precision for VAE / VAE処理時にfp16/bf16を使わない") + parser.add_argument( + "--no_half_vae", action="store_true", help="do not use fp16/bf16 precision for VAE / VAE処理時にfp16/bf16を使わない" + ) parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数") parser.add_argument( "--sampler", @@ -2651,9 +2945,14 @@ def setup_parser() -> argparse.ArgumentParser: default=7.5, help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty)) / guidance scale", ) - parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ") parser.add_argument( - "--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ" + "--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ" + ) + parser.add_argument( + "--vae", + type=str, + default=None, + help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ", ) parser.add_argument( "--tokenizer_cache_dir", @@ -2684,25 +2983,46 @@ def setup_parser() -> argparse.ArgumentParser: help="use xformers by diffusers (Hypernetworks doesn't work) / Diffusersでxformersを使用する(Hypernetwork利用不可)", ) parser.add_argument( - "--opt_channels_last", action="store_true", help="set channels last option to model / モデルにchannels lastを指定し最適化する" + "--opt_channels_last", + action="store_true", + help="set channels last option to model / モデルにchannels lastを指定し最適化する", ) parser.add_argument( - "--network_module", type=str, default=None, nargs="*", help="additional network module to use / 追加ネットワークを使う時そのモジュール名" + "--network_module", + type=str, + default=None, + nargs="*", + help="additional network module to use / 追加ネットワークを使う時そのモジュール名", ) parser.add_argument( "--network_weights", type=str, default=None, nargs="*", help="additional network weights to load / 追加ネットワークの重み" ) - parser.add_argument("--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率") parser.add_argument( - "--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数" + "--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率" ) - parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する") parser.add_argument( - "--network_merge_n_models", type=int, default=None, help="merge this number of networks / この数だけネットワークをマージする" + "--network_args", + type=str, + default=None, + nargs="*", + help="additional arguments for network (key=value) / ネットワークへの追加の引数", + ) + parser.add_argument( + "--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する" ) - parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする") parser.add_argument( - "--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する" + "--network_merge_n_models", + type=int, + default=None, + help="merge this number of networks / この数だけネットワークをマージする", + ) + parser.add_argument( + "--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする" + ) + parser.add_argument( + "--network_pre_calc", + action="store_true", + help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する", ) parser.add_argument( "--network_regional_mask_max_color_codes", @@ -2717,7 +3037,9 @@ def setup_parser() -> argparse.ArgumentParser: nargs="*", help="Embeddings files of Textual Inversion / Textual Inversionのembeddings", ) - parser.add_argument("--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う") + parser.add_argument( + "--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う" + ) parser.add_argument( "--max_embeddings_multiples", type=int, @@ -2734,7 +3056,10 @@ def setup_parser() -> argparse.ArgumentParser: help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする", ) parser.add_argument( - "--highres_fix_steps", type=int, default=28, help="1st stage steps for highres fix / highres fixの最初のステージのステップ数" + "--highres_fix_steps", + type=int, + default=28, + help="1st stage steps for highres fix / highres fixの最初のステージのステップ数", ) parser.add_argument( "--highres_fix_strength", @@ -2743,7 +3068,9 @@ def setup_parser() -> argparse.ArgumentParser: help="1st stage img2img strength for highres fix / highres fixの最初のステージのimg2img時のstrength、省略時はstrengthと同じ", ) parser.add_argument( - "--highres_fix_save_1st", action="store_true", help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する" + "--highres_fix_save_1st", + action="store_true", + help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する", ) parser.add_argument( "--highres_fix_latents_upscaling", @@ -2751,7 +3078,10 @@ def setup_parser() -> argparse.ArgumentParser: help="use latents upscaling for highres fix / highres fixでlatentで拡大する", ) parser.add_argument( - "--highres_fix_upscaler", type=str, default=None, help="upscaler module for highres fix / highres fixで使うupscalerのモジュール名" + "--highres_fix_upscaler", + type=str, + default=None, + help="upscaler module for highres fix / highres fixで使うupscalerのモジュール名", ) parser.add_argument( "--highres_fix_upscaler_args", @@ -2766,11 +3096,18 @@ def setup_parser() -> argparse.ArgumentParser: ) parser.add_argument( - "--negative_scale", type=float, default=None, help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する" + "--negative_scale", + type=float, + default=None, + help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する", ) parser.add_argument( - "--control_net_lllite_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名" + "--control_net_lllite_models", + type=str, + default=None, + nargs="*", + help="ControlNet models to use / 使用するControlNetのモデル名", ) # parser.add_argument( # "--control_net_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名" @@ -2819,6 +3156,45 @@ def setup_parser() -> argparse.ArgumentParser: "--ds_ratio", type=float, default=0.5, help="Deep Shrink ratio for downsampling / Deep Shrinkのdownsampling比率" ) + # gradual latent + parser.add_argument( + "--gradual_latent_timesteps", + type=int, + default=None, + help="enable Gradual Latent hires fix and apply upscaling from this timesteps / Gradual Latent hires fixをこのtimestepsで有効にし、このtimestepsからアップスケーリングを適用する", + ) + parser.add_argument( + "--gradual_latent_ratio", + type=float, + default=0.5, + help=" this size ratio, 0.5 means 1/2 / Gradual Latent hires fixをこのサイズ比率で有効にする、0.5は1/2を意味する", + ) + parser.add_argument( + "--gradual_latent_ratio_step", + type=float, + default=0.125, + help="step to increase ratio for Gradual Latent / Gradual Latentのratioをどのくらいずつ上げるか", + ) + parser.add_argument( + "--gradual_latent_every_n_steps", + type=int, + default=3, + help="steps to increase size of latents every this steps for Gradual Latent / Gradual Latentでlatentsのサイズをこのステップごとに上げる", + ) + parser.add_argument( + "--gradual_latent_s_noise", + type=float, + default=1.0, + help="s_noise for Gradual Latent / Gradual Latentのs_noise", + ) + parser.add_argument( + "--gradual_latent_unsharp_params", + type=str, + default=None, + help="unsharp mask parameters for Gradual Latent: ksize, sigma, strength, target-x (1 means True). `3,0.5,0.5,1` or `3,1.0,1.0,0` is recommended /" + + " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength, target-x. `3,0.5,0.5,1` または `3,1.0,1.0,0` が推奨", + ) + # # parser.add_argument( # "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像" # ) @@ -2830,4 +3206,5 @@ def setup_parser() -> argparse.ArgumentParser: parser = setup_parser() args = parser.parse_args() + setup_logging(args, reset=True) main(args) diff --git a/sdxl_minimal_inference.py b/sdxl_minimal_inference.py index 45b9edd65..a1e93b7f0 100644 --- a/sdxl_minimal_inference.py +++ b/sdxl_minimal_inference.py @@ -8,23 +8,28 @@ import random from einops import repeat import numpy as np + import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass +from library.device_utils import init_ipex, get_preferred_device + +init_ipex() + from tqdm import tqdm from transformers import CLIPTokenizer from diffusers import EulerDiscreteScheduler from PIL import Image -import open_clip + +# import open_clip from safetensors.torch import load_file from library import model_util, sdxl_model_util import networks.lora as lora +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) # scheduler: このあたりの設定はSD1/2と同じでいいらしい # scheduler: The settings around here seem to be the same as SD1/2 @@ -87,7 +92,7 @@ def get_timestep_embedding(x, outdim): guidance_scale = 7 seed = None # 1 - DEVICE = "cuda" + DEVICE = get_preferred_device() DTYPE = torch.float16 # bfloat16 may work parser = argparse.ArgumentParser() @@ -142,7 +147,7 @@ def get_timestep_embedding(x, outdim): vae_dtype = DTYPE if DTYPE == torch.float16: - print("use float32 for vae") + logger.info("use float32 for vae") vae_dtype = torch.float32 vae.to(DEVICE, dtype=vae_dtype) vae.eval() @@ -153,12 +158,13 @@ def get_timestep_embedding(x, outdim): text_model2.eval() unet.set_use_memory_efficient_attention(True, False) - if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える + if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える vae.set_use_memory_efficient_attention_xformers(True) # Tokenizers tokenizer1 = CLIPTokenizer.from_pretrained(text_encoder_1_name) - tokenizer2 = lambda x: open_clip.tokenize(x, context_length=77) + # tokenizer2 = lambda x: open_clip.tokenize(x, context_length=77) + tokenizer2 = CLIPTokenizer.from_pretrained(text_encoder_2_name) # LoRA for weights_file in args.lora_weights: @@ -189,9 +195,11 @@ def generate_image(prompt, prompt2, negative_prompt, seed=None): emb1 = get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256) emb2 = get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256) emb3 = get_timestep_embedding(torch.FloatTensor([target_height, target_width]).unsqueeze(0), 256) - # print("emb1", emb1.shape) + # logger.info("emb1", emb1.shape) c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(DEVICE, dtype=DTYPE) - uc_vector = c_vector.clone().to(DEVICE, dtype=DTYPE) # ちょっとここ正しいかどうかわからない I'm not sure if this is right + uc_vector = c_vector.clone().to( + DEVICE, dtype=DTYPE + ) # ちょっとここ正しいかどうかわからない I'm not sure if this is right # crossattn @@ -214,13 +222,22 @@ def call_text_encoder(text, text2): # text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding) # layer normは通さないらしい # text encoder 2 - with torch.no_grad(): - tokens = tokenizer2(text2).to(DEVICE) + # tokens = tokenizer2(text2).to(DEVICE) + tokens = tokenizer2( + text, + truncation=True, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + tokens = batch_encoding["input_ids"].to(DEVICE) + with torch.no_grad(): enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True) text_embedding2_penu = enc_out["hidden_states"][-2] - # print("hidden_states2", text_embedding2_penu.shape) - text_embedding2_pool = enc_out["text_embeds"] # do not support Textual Inversion + # logger.info("hidden_states2", text_embedding2_penu.shape) + text_embedding2_pool = enc_out["text_embeds"] # do not support Textual Inversion # 連結して終了 concat and finish text_embedding = torch.cat([text_embedding1, text_embedding2_penu], dim=2) @@ -228,7 +245,7 @@ def call_text_encoder(text, text2): # cond c_ctx, c_ctx_pool = call_text_encoder(prompt, prompt2) - # print(c_ctx.shape, c_ctx_p.shape, c_vector.shape) + # logger.info(c_ctx.shape, c_ctx_p.shape, c_vector.shape) c_vector = torch.cat([c_ctx_pool, c_vector], dim=1) # uncond @@ -325,4 +342,4 @@ def call_text_encoder(text, text2): seed = int(seed) generate_image(prompt, prompt2, negative_prompt, seed) - print("Done!") + logger.info("Done!") diff --git a/sdxl_train.py b/sdxl_train.py index b4ce2770e..107bb9451 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -1,7 +1,6 @@ # training with captions import argparse -import gc import math import os from multiprocessing import Value @@ -9,22 +8,24 @@ import toml from tqdm import tqdm -import torch - -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +import torch +from library.device_utils import init_ipex, clean_memory_on_device +init_ipex() - ipex_init() -except Exception: - pass from accelerate.utils import set_seed from diffusers import DDPMScheduler from library import sdxl_model_util import library.train_util as train_util + +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + import library.config_util as config_util import library.sdxl_train_util as sdxl_train_util from library.config_util import ( @@ -96,8 +97,11 @@ def train(args): train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) sdxl_train_util.verify_sdxl_training_args(args) + setup_logging(args, reset=True) - assert not args.weighted_captions, "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" + assert ( + not args.weighted_captions + ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" assert ( not args.train_text_encoder or not args.cache_text_encoder_outputs ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" @@ -122,18 +126,18 @@ def train(args): if args.dataset_class is None: blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) if args.dataset_config is not None: - print(f"Load dataset config from {args.dataset_config}") + logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) ) else: if use_dreambooth_method: - print("Using DreamBooth method.") + logger.info("Using DreamBooth method.") user_config = { "datasets": [ { @@ -144,7 +148,7 @@ def train(args): ] } else: - print("Training with captions.") + logger.info("Training with captions.") user_config = { "datasets": [ { @@ -174,7 +178,7 @@ def train(args): train_util.debug_dataset(train_dataset_group, True) return if len(train_dataset_group) == 0: - print( + logger.error( "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" ) return @@ -190,7 +194,7 @@ def train(args): ), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") accelerator = train_util.prepare_accelerator(args) # mixed precisionに対応した型を用意しておき適宜castする @@ -257,9 +261,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() @@ -352,8 +354,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) # dataloaderを準備する - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( train_dataset_group, batch_size=1, @@ -368,7 +370,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): args.max_train_steps = args.max_train_epochs * math.ceil( len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) - accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) @@ -412,8 +416,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 text_encoder1.to("cpu", dtype=torch.float32) text_encoder2.to("cpu", dtype=torch.float32) - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory_on_device(accelerator.device) else: # make sure Text Encoders are on GPU text_encoder1.to(accelerator.device) @@ -438,7 +441,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}") - accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) # accelerator.print( # f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" # ) @@ -458,7 +463,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): if accelerator.is_main_process: init_kwargs = {} if args.wandb_run_name: - init_kwargs['wandb'] = {'name': args.wandb_run_name} + init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) @@ -542,7 +547,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): # assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 # assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 # assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 - # print("text encoder outputs verified") + # logger.info("text encoder outputs verified") # get size embeddings orig_size = batch["original_sizes_hw"] @@ -707,7 +712,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): accelerator.end_training() - if args.save_state: # and is_main_process: + if args.save_state or args.save_state_on_train_end: train_util.save_state_on_train_end(args, accelerator) del accelerator # この後メモリを使うのでこれは消す @@ -729,12 +734,13 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): logit_scale, ckpt_info, ) - print("model saved.") + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, True, True, True) train_util.add_training_arguments(parser, False) @@ -757,7 +763,9 @@ def setup_parser() -> argparse.ArgumentParser: help="learning rate for text encoder 2 (BiG-G) / text encoder 2 (BiG-G)の学習率", ) - parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する") + parser.add_argument( + "--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する" + ) parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") parser.add_argument( "--no_half_vae", diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 4436dd3cd..e99b4e35c 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -2,7 +2,6 @@ # training code for ControlNet-LLLite with passing cond_image to U-Net's forward import argparse -import gc import json import math import os @@ -13,14 +12,11 @@ import toml from tqdm import tqdm + import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass +from library.device_utils import init_ipex, clean_memory_on_device +init_ipex() + from torch.nn.parallel import DistributedDataParallel as DDP from accelerate.utils import set_seed import accelerate @@ -47,6 +43,12 @@ apply_debiased_estimation, ) import networks.control_net_lllite_for_train as control_net_lllite_for_train +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) # TODO 他のスクリプトと共通化する @@ -67,6 +69,7 @@ def train(args): train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) sdxl_train_util.verify_sdxl_training_args(args) + setup_logging(args, reset=True) cache_latents = args.cache_latents use_user_config = args.dataset_config is not None @@ -80,11 +83,11 @@ def train(args): # データセットを準備する blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) if use_user_config: - print(f"Load dataset config from {args.dataset_config}") + logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "conditioning_data_dir"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) @@ -116,7 +119,7 @@ def train(args): train_util.debug_dataset(train_dataset_group) return if len(train_dataset_group) == 0: - print( + logger.error( "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" ) return @@ -126,7 +129,9 @@ def train(args): train_dataset_group.is_latent_cacheable() ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" else: - print("WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません") + logger.warning( + "WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません" + ) if args.cache_text_encoder_outputs: assert ( @@ -134,7 +139,7 @@ def train(args): ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") accelerator = train_util.prepare_accelerator(args) is_main_process = accelerator.is_main_process @@ -166,9 +171,7 @@ def train(args): accelerator.is_main_process, ) vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() @@ -233,14 +236,14 @@ def train(args): accelerator.print("prepare optimizer, data loader etc.") trainable_params = list(unet.prepare_params()) - print(f"trainable params count: {len(trainable_params)}") - print(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}") + logger.info(f"trainable params count: {len(trainable_params)}") + logger.info(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}") _, _, optimizer = train_util.get_optimizer(args, trainable_params) # dataloaderを準備する - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( train_dataset_group, @@ -256,7 +259,9 @@ def train(args): args.max_train_steps = args.max_train_epochs * math.ceil( len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) - accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) @@ -293,8 +298,7 @@ def train(args): # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 text_encoder1.to("cpu", dtype=torch.float32) text_encoder2.to("cpu", dtype=torch.float32) - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory_on_device(accelerator.device) else: # make sure Text Encoders are on GPU text_encoder1.to(accelerator.device) @@ -325,8 +329,10 @@ def train(args): accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}") - accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") - # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") @@ -343,7 +349,7 @@ def train(args): if accelerator.is_main_process: init_kwargs = {} if args.wandb_run_name: - init_kwargs['wandb'] = {'name': args.wandb_run_name} + init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( @@ -543,19 +549,20 @@ def remove_model(old_ckpt_name): accelerator.end_training() - if is_main_process and args.save_state: + if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) if is_main_process: ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) save_model(ckpt_name, unet, global_step, num_train_epochs, force_sync_upload=True) - print("model saved.") + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) @@ -571,8 +578,12 @@ def setup_parser() -> argparse.ArgumentParser: choices=[None, "ckpt", "pt", "safetensors"], help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)", ) - parser.add_argument("--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数") - parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み") + parser.add_argument( + "--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数" + ) + parser.add_argument( + "--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み" + ) parser.add_argument("--network_dim", type=int, default=None, help="network dimensions (rank) / モジュールの次元数") parser.add_argument( "--network_dropout", diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index 6ae5377ba..dac56eedd 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -1,5 +1,4 @@ import argparse -import gc import json import math import os @@ -10,14 +9,11 @@ import toml from tqdm import tqdm + import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass +from library.device_utils import init_ipex, clean_memory_on_device +init_ipex() + from torch.nn.parallel import DistributedDataParallel as DDP from accelerate.utils import set_seed from diffusers import DDPMScheduler, ControlNetModel @@ -43,6 +39,12 @@ apply_debiased_estimation, ) import networks.control_net_lllite as control_net_lllite +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) # TODO 他のスクリプトと共通化する @@ -63,6 +65,7 @@ def train(args): train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) sdxl_train_util.verify_sdxl_training_args(args) + setup_logging(args, reset=True) cache_latents = args.cache_latents use_user_config = args.dataset_config is not None @@ -76,11 +79,11 @@ def train(args): # データセットを準備する blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) if use_user_config: - print(f"Load dataset config from {args.dataset_config}") + logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "conditioning_data_dir"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) @@ -112,7 +115,7 @@ def train(args): train_util.debug_dataset(train_dataset_group) return if len(train_dataset_group) == 0: - print( + logger.error( "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" ) return @@ -122,7 +125,9 @@ def train(args): train_dataset_group.is_latent_cacheable() ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" else: - print("WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません") + logger.warning( + "WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません" + ) if args.cache_text_encoder_outputs: assert ( @@ -130,7 +135,7 @@ def train(args): ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") accelerator = train_util.prepare_accelerator(args) is_main_process = accelerator.is_main_process @@ -165,9 +170,7 @@ def train(args): accelerator.is_main_process, ) vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() @@ -201,14 +204,14 @@ def train(args): accelerator.print("prepare optimizer, data loader etc.") trainable_params = list(network.prepare_optimizer_params()) - print(f"trainable params count: {len(trainable_params)}") - print(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}") + logger.info(f"trainable params count: {len(trainable_params)}") + logger.info(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}") _, _, optimizer = train_util.get_optimizer(args, trainable_params) # dataloaderを準備する - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( train_dataset_group, @@ -224,7 +227,9 @@ def train(args): args.max_train_steps = args.max_train_epochs * math.ceil( len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) - accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) @@ -266,8 +271,7 @@ def train(args): # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 text_encoder1.to("cpu", dtype=torch.float32) text_encoder2.to("cpu", dtype=torch.float32) - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory_on_device(accelerator.device) else: # make sure Text Encoders are on GPU text_encoder1.to(accelerator.device) @@ -298,8 +302,10 @@ def train(args): accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}") - accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") - # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") @@ -518,12 +524,13 @@ def remove_model(old_ckpt_name): ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True) - print("model saved.") + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) @@ -539,8 +546,12 @@ def setup_parser() -> argparse.ArgumentParser: choices=[None, "ckpt", "pt", "safetensors"], help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)", ) - parser.add_argument("--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数") - parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み") + parser.add_argument( + "--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数" + ) + parser.add_argument( + "--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み" + ) parser.add_argument("--network_dim", type=int, default=None, help="network dimensions (rank) / モジュールの次元数") parser.add_argument( "--network_dropout", diff --git a/sdxl_train_network.py b/sdxl_train_network.py index a35779d00..d33239d92 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -1,18 +1,15 @@ import argparse -import torch - -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +import torch +from library.device_utils import init_ipex, clean_memory_on_device +init_ipex() - ipex_init() -except Exception: - pass from library import sdxl_model_util, sdxl_train_util, train_util import train_network - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) class SdxlNetworkTrainer(train_network.NetworkTrainer): def __init__(self): @@ -65,13 +62,12 @@ def cache_text_encoder_outputs_if_needed( if args.cache_text_encoder_outputs: if not args.lowram: # メモリ消費を減らす - print("move vae and unet to cpu to save memory") + logger.info("move vae and unet to cpu to save memory") org_vae_device = vae.device org_unet_device = unet.device vae.to("cpu") unet.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory_on_device(accelerator.device) # When TE is not be trained, it will not be prepared so we need to use explicit autocast with accelerator.autocast(): @@ -86,17 +82,16 @@ def cache_text_encoder_outputs_if_needed( text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU text_encoders[1].to("cpu", dtype=torch.float32) - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory_on_device(accelerator.device) if not args.lowram: - print("move vae and unet back to original device") + logger.info("move vae and unet back to original device") vae.to(org_vae_device) unet.to(org_unet_device) else: # Text Encoderから毎回出力を取得するので、GPUに乗せておく - text_encoders[0].to(accelerator.device) - text_encoders[1].to(accelerator.device) + text_encoders[0].to(accelerator.device, dtype=weight_dtype) + text_encoders[1].to(accelerator.device, dtype=weight_dtype) def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: @@ -148,7 +143,7 @@ def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, wei # assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 # assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 # assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 - # print("text encoder outputs verified") + # logger.info("text encoder outputs verified") return encoder_hidden_states1, encoder_hidden_states2, pool2 diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py index f8a1d7bce..257d181ad 100644 --- a/sdxl_train_textual_inversion.py +++ b/sdxl_train_textual_inversion.py @@ -2,15 +2,11 @@ import os import regex + import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass -import open_clip +from library.device_utils import init_ipex +init_ipex() + from library import sdxl_model_util, sdxl_train_util, train_util import train_textual_inversion diff --git a/tools/cache_latents.py b/tools/cache_latents.py index 17916ef70..347db27f7 100644 --- a/tools/cache_latents.py +++ b/tools/cache_latents.py @@ -16,7 +16,10 @@ ConfigSanitizer, BlueprintGenerator, ) - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def cache_to_disk(args: argparse.Namespace) -> None: train_util.prepare_dataset_args(args, True) @@ -41,18 +44,18 @@ def cache_to_disk(args: argparse.Namespace) -> None: if args.dataset_class is None: blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) if args.dataset_config is not None: - print(f"Load dataset config from {args.dataset_config}") + logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) ) else: if use_dreambooth_method: - print("Using DreamBooth method.") + logger.info("Using DreamBooth method.") user_config = { "datasets": [ { @@ -63,7 +66,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: ] } else: - print("Training with captions.") + logger.info("Training with captions.") user_config = { "datasets": [ { @@ -90,7 +93,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") accelerator = train_util.prepare_accelerator(args) # mixed precisionに対応した型を用意しておき適宜castする @@ -98,7 +101,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: vae_dtype = torch.float32 if args.no_half_vae else weight_dtype # モデルを読み込む - print("load model") + logger.info("load model") if args.sdxl: (_, _, _, vae, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype) else: @@ -113,8 +116,8 @@ def cache_to_disk(args: argparse.Namespace) -> None: # dataloaderを準備する train_dataset_group.set_caching_mode("latents") - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( train_dataset_group, @@ -152,7 +155,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: if args.skip_existing: if train_util.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug): - print(f"Skipping {image_info.latents_npz} because it already exists.") + logger.warning(f"Skipping {image_info.latents_npz} because it already exists.") continue image_infos.append(image_info) diff --git a/tools/cache_text_encoder_outputs.py b/tools/cache_text_encoder_outputs.py index 7d9b13d68..5f1d6d201 100644 --- a/tools/cache_text_encoder_outputs.py +++ b/tools/cache_text_encoder_outputs.py @@ -16,7 +16,10 @@ ConfigSanitizer, BlueprintGenerator, ) - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def cache_to_disk(args: argparse.Namespace) -> None: train_util.prepare_dataset_args(args, True) @@ -48,18 +51,18 @@ def cache_to_disk(args: argparse.Namespace) -> None: if args.dataset_class is None: blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) if args.dataset_config is not None: - print(f"Load dataset config from {args.dataset_config}") + logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) ) else: if use_dreambooth_method: - print("Using DreamBooth method.") + logger.info("Using DreamBooth method.") user_config = { "datasets": [ { @@ -70,7 +73,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: ] } else: - print("Training with captions.") + logger.info("Training with captions.") user_config = { "datasets": [ { @@ -95,14 +98,14 @@ def cache_to_disk(args: argparse.Namespace) -> None: collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") accelerator = train_util.prepare_accelerator(args) # mixed precisionに対応した型を用意しておき適宜castする weight_dtype, _ = train_util.prepare_dtype(args) # モデルを読み込む - print("load model") + logger.info("load model") if args.sdxl: (_, text_encoder1, text_encoder2, _, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype) text_encoders = [text_encoder1, text_encoder2] @@ -118,8 +121,8 @@ def cache_to_disk(args: argparse.Namespace) -> None: # dataloaderを準備する train_dataset_group.set_caching_mode("text") - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( train_dataset_group, @@ -147,7 +150,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: if args.skip_existing: if os.path.exists(image_info.text_encoder_outputs_npz): - print(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.") + logger.warning(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.") continue image_info.input_ids1 = input_ids1 diff --git a/tools/canny.py b/tools/canny.py index 5e0806898..f2190975c 100644 --- a/tools/canny.py +++ b/tools/canny.py @@ -1,6 +1,10 @@ import argparse import cv2 +import logging +from library.utils import setup_logging +setup_logging() +logger = logging.getLogger(__name__) def canny(args): img = cv2.imread(args.input) @@ -10,7 +14,7 @@ def canny(args): # canny_img = 255 - canny_img cv2.imwrite(args.output, canny_img) - print("done!") + logger.info("done!") def setup_parser() -> argparse.ArgumentParser: diff --git a/tools/convert_diffusers20_original_sd.py b/tools/convert_diffusers20_original_sd.py index fe30996aa..572ee2f0c 100644 --- a/tools/convert_diffusers20_original_sd.py +++ b/tools/convert_diffusers20_original_sd.py @@ -6,7 +6,10 @@ from diffusers import StableDiffusionPipeline import library.model_util as model_util - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def convert(args): # 引数を確認する @@ -30,7 +33,7 @@ def convert(args): # モデルを読み込む msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else "")) - print(f"loading {msg}: {args.model_to_load}") + logger.info(f"loading {msg}: {args.model_to_load}") if is_load_ckpt: v2_model = args.v2 @@ -48,13 +51,13 @@ def convert(args): if args.v1 == args.v2: # 自動判定する v2_model = unet.config.cross_attention_dim == 1024 - print("checking model version: model is " + ("v2" if v2_model else "v1")) + logger.info("checking model version: model is " + ("v2" if v2_model else "v1")) else: v2_model = not args.v1 # 変換して保存する msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers" - print(f"converting and saving as {msg}: {args.model_to_save}") + logger.info(f"converting and saving as {msg}: {args.model_to_save}") if is_save_ckpt: original_model = args.model_to_load if is_load_ckpt else None @@ -70,15 +73,15 @@ def convert(args): save_dtype=save_dtype, vae=vae, ) - print(f"model saved. total converted state_dict keys: {key_count}") + logger.info(f"model saved. total converted state_dict keys: {key_count}") else: - print( + logger.info( f"copy scheduler/tokenizer config from: {args.reference_model if args.reference_model is not None else 'default model'}" ) model_util.save_diffusers_checkpoint( v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors ) - print("model saved.") + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: diff --git a/tools/detect_face_rotate.py b/tools/detect_face_rotate.py index 68dec6cae..bbc643edc 100644 --- a/tools/detect_face_rotate.py +++ b/tools/detect_face_rotate.py @@ -15,6 +15,10 @@ from anime_face_detector import create_detector from tqdm import tqdm import numpy as np +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) KP_REYE = 11 KP_LEYE = 19 @@ -24,7 +28,7 @@ def detect_faces(detector, image, min_size): preds = detector(image) # bgr - # print(len(preds)) + # logger.info(len(preds)) faces = [] for pred in preds: @@ -78,7 +82,7 @@ def process(args): assert args.crop_ratio is None or args.resize_face_size is None, f"crop_ratio指定時はresize_face_sizeは指定できません" # アニメ顔検出モデルを読み込む - print("loading face detector.") + logger.info("loading face detector.") detector = create_detector('yolov3') # cropの引数を解析する @@ -97,7 +101,7 @@ def process(args): crop_h_ratio, crop_v_ratio = [float(t) for t in tokens] # 画像を処理する - print("processing.") + logger.info("processing.") output_extension = ".png" os.makedirs(args.dst_dir, exist_ok=True) @@ -111,7 +115,7 @@ def process(args): if len(image.shape) == 2: image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) if image.shape[2] == 4: - print(f"image has alpha. ignore / 画像の透明度が設定されているため無視します: {path}") + logger.warning(f"image has alpha. ignore / 画像の透明度が設定されているため無視します: {path}") image = image[:, :, :3].copy() # copyをしないと内部的に透明度情報が付いたままになるらしい h, w = image.shape[:2] @@ -144,11 +148,11 @@ def process(args): # 顔サイズを基準にリサイズする scale = args.resize_face_size / face_size if scale < cur_crop_width / w: - print( + logger.warning( f"image width too small in face size based resizing / 顔を基準にリサイズすると画像の幅がcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}") scale = cur_crop_width / w if scale < cur_crop_height / h: - print( + logger.warning( f"image height too small in face size based resizing / 顔を基準にリサイズすると画像の高さがcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}") scale = cur_crop_height / h elif crop_h_ratio is not None: @@ -157,10 +161,10 @@ def process(args): else: # 切り出しサイズ指定あり if w < cur_crop_width: - print(f"image width too small/ 画像の幅がcrop sizeより小さいので画質が劣化します: {path}") + logger.warning(f"image width too small/ 画像の幅がcrop sizeより小さいので画質が劣化します: {path}") scale = cur_crop_width / w if h < cur_crop_height: - print(f"image height too small/ 画像の高さがcrop sizeより小さいので画質が劣化します: {path}") + logger.warning(f"image height too small/ 画像の高さがcrop sizeより小さいので画質が劣化します: {path}") scale = cur_crop_height / h if args.resize_fit: scale = max(cur_crop_width / w, cur_crop_height / h) @@ -198,7 +202,7 @@ def process(args): face_img = face_img[y:y + cur_crop_height] # # debug - # print(path, cx, cy, angle) + # logger.info(path, cx, cy, angle) # crp = cv2.resize(image, (image.shape[1]//8, image.shape[0]//8)) # cv2.imshow("image", crp) # if cv2.waitKey() == 27: diff --git a/tools/latent_upscaler.py b/tools/latent_upscaler.py index ab1fa3390..f05cf7194 100644 --- a/tools/latent_upscaler.py +++ b/tools/latent_upscaler.py @@ -11,10 +11,16 @@ import numpy as np import torch +from library.device_utils import init_ipex, get_preferred_device +init_ipex() + from torch import nn from tqdm import tqdm from PIL import Image - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) class ResidualBlock(nn.Module): def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1): @@ -216,7 +222,7 @@ def upscale( upsampled_images = upsampled_images / 127.5 - 1.0 # convert upsample images to latents with batch size - # print("Encoding upsampled (LANCZOS4) images...") + # logger.info("Encoding upsampled (LANCZOS4) images...") upsampled_latents = [] for i in tqdm(range(0, upsampled_images.shape[0], vae_batch_size)): batch = upsampled_images[i : i + vae_batch_size].to(vae.device) @@ -227,7 +233,7 @@ def upscale( upsampled_latents = torch.cat(upsampled_latents, dim=0) # upscale (refine) latents with this model with batch size - print("Upscaling latents...") + logger.info("Upscaling latents...") upscaled_latents = [] for i in range(0, upsampled_latents.shape[0], batch_size): with torch.no_grad(): @@ -242,7 +248,7 @@ def create_upscaler(**kwargs): weights = kwargs["weights"] model = Upscaler() - print(f"Loading weights from {weights}...") + logger.info(f"Loading weights from {weights}...") if os.path.splitext(weights)[1] == ".safetensors": from safetensors.torch import load_file @@ -255,20 +261,20 @@ def create_upscaler(**kwargs): # another interface: upscale images with a model for given images from command line def upscale_images(args: argparse.Namespace): - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + DEVICE = get_preferred_device() us_dtype = torch.float16 # TODO: support fp32/bf16 os.makedirs(args.output_dir, exist_ok=True) # load VAE with Diffusers assert args.vae_path is not None, "VAE path is required" - print(f"Loading VAE from {args.vae_path}...") + logger.info(f"Loading VAE from {args.vae_path}...") vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae") vae.to(DEVICE, dtype=us_dtype) # prepare model - print("Preparing model...") + logger.info("Preparing model...") upscaler: Upscaler = create_upscaler(weights=args.weights) - # print("Loading weights from", args.weights) + # logger.info("Loading weights from", args.weights) # upscaler.load_state_dict(torch.load(args.weights)) upscaler.eval() upscaler.to(DEVICE, dtype=us_dtype) @@ -303,14 +309,14 @@ def upscale_images(args: argparse.Namespace): image_debug.save(dest_file_name) # upscale - print("Upscaling...") + logger.info("Upscaling...") upscaled_latents = upscaler.upscale( vae, images, None, us_dtype, width * 2, height * 2, batch_size=args.batch_size, vae_batch_size=args.vae_batch_size ) upscaled_latents /= 0.18215 # decode with batch - print("Decoding...") + logger.info("Decoding...") upscaled_images = [] for i in tqdm(range(0, upscaled_latents.shape[0], args.vae_batch_size)): with torch.no_grad(): diff --git a/tools/merge_models.py b/tools/merge_models.py index 391bfe677..8f1fbf2f8 100644 --- a/tools/merge_models.py +++ b/tools/merge_models.py @@ -5,7 +5,10 @@ from safetensors import safe_open from safetensors.torch import load_file, save_file from tqdm import tqdm - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def is_unet_key(key): # VAE or TextEncoder, the last one is for SDXL @@ -45,10 +48,10 @@ def merge(args): # check if all models are safetensors for model in args.models: if not model.endswith("safetensors"): - print(f"Model {model} is not a safetensors model") + logger.info(f"Model {model} is not a safetensors model") exit() if not os.path.isfile(model): - print(f"Model {model} does not exist") + logger.info(f"Model {model} does not exist") exit() assert args.ratios is None or len(args.models) == len(args.ratios), "ratios must be the same length as models" @@ -65,7 +68,7 @@ def merge(args): if merged_sd is None: # load first model - print(f"Loading model {model}, ratio = {ratio}...") + logger.info(f"Loading model {model}, ratio = {ratio}...") merged_sd = {} with safe_open(model, framework="pt", device=args.device) as f: for key in tqdm(f.keys()): @@ -81,11 +84,11 @@ def merge(args): value = ratio * value.to(dtype) # first model's value * ratio merged_sd[key] = value - print(f"Model has {len(merged_sd)} keys " + ("(UNet only)" if args.unet_only else "")) + logger.info(f"Model has {len(merged_sd)} keys " + ("(UNet only)" if args.unet_only else "")) continue # load other models - print(f"Loading model {model}, ratio = {ratio}...") + logger.info(f"Loading model {model}, ratio = {ratio}...") with safe_open(model, framework="pt", device=args.device) as f: model_keys = f.keys() @@ -93,7 +96,7 @@ def merge(args): _, new_key = replace_text_encoder_key(key) if new_key not in merged_sd: if args.show_skipped and new_key not in first_model_keys: - print(f"Skip: {new_key}") + logger.info(f"Skip: {new_key}") continue value = f.get_tensor(key) @@ -104,7 +107,7 @@ def merge(args): for key in merged_sd.keys(): if key in model_keys: continue - print(f"Key {key} not in model {model}, use first model's value") + logger.warning(f"Key {key} not in model {model}, use first model's value") if key in supplementary_key_ratios: supplementary_key_ratios[key] += ratio else: @@ -112,7 +115,7 @@ def merge(args): # add supplementary keys' value (including VAE and TextEncoder) if len(supplementary_key_ratios) > 0: - print("add first model's value") + logger.info("add first model's value") with safe_open(args.models[0], framework="pt", device=args.device) as f: for key in tqdm(f.keys()): _, new_key = replace_text_encoder_key(key) @@ -120,7 +123,7 @@ def merge(args): continue if is_unet_key(new_key): # not VAE or TextEncoder - print(f"Key {new_key} not in all models, ratio = {supplementary_key_ratios[new_key]}") + logger.warning(f"Key {new_key} not in all models, ratio = {supplementary_key_ratios[new_key]}") value = f.get_tensor(key) # original key @@ -134,7 +137,7 @@ def merge(args): if not output_file.endswith(".safetensors"): output_file = output_file + ".safetensors" - print(f"Saving to {output_file}...") + logger.info(f"Saving to {output_file}...") # convert to save_dtype for k in merged_sd.keys(): @@ -142,7 +145,7 @@ def merge(args): save_file(merged_sd, output_file) - print("Done!") + logger.info("Done!") if __name__ == "__main__": diff --git a/tools/original_control_net.py b/tools/original_control_net.py index cd47bd76a..5640d542d 100644 --- a/tools/original_control_net.py +++ b/tools/original_control_net.py @@ -7,7 +7,10 @@ from library.original_unet import UNet2DConditionModel, SampleOutput import library.model_util as model_util - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) class ControlNetInfo(NamedTuple): unet: Any @@ -51,7 +54,7 @@ def load_control_net(v2, unet, model): # control sdからキー変換しつつU-Netに対応する部分のみ取り出し、DiffusersのU-Netに読み込む # state dictを読み込む - print(f"ControlNet: loading control SD model : {model}") + logger.info(f"ControlNet: loading control SD model : {model}") if model_util.is_safetensors(model): ctrl_sd_sd = load_file(model) @@ -61,7 +64,7 @@ def load_control_net(v2, unet, model): # 重みをU-Netに読み込めるようにする。ControlNetはSD版のstate dictなので、それを読み込む is_difference = "difference" in ctrl_sd_sd - print("ControlNet: loading difference:", is_difference) + logger.info(f"ControlNet: loading difference: {is_difference}") # ControlNetには存在しないキーがあるので、まず現在のU-NetでSD版の全keyを作っておく # またTransfer Controlの元weightとなる @@ -89,13 +92,13 @@ def load_control_net(v2, unet, model): # ControlNetのU-Netを作成する ctrl_unet = UNet2DConditionModel(**unet_config) info = ctrl_unet.load_state_dict(ctrl_unet_du_sd) - print("ControlNet: loading Control U-Net:", info) + logger.info(f"ControlNet: loading Control U-Net: {info}") # U-Net以外のControlNetを作成する # TODO support middle only ctrl_net = ControlNet() info = ctrl_net.load_state_dict(zero_conv_sd) - print("ControlNet: loading ControlNet:", info) + logger.info("ControlNet: loading ControlNet: {info}") ctrl_unet.to(unet.device, dtype=unet.dtype) ctrl_net.to(unet.device, dtype=unet.dtype) @@ -117,7 +120,7 @@ def canny(img): return canny - print("Unsupported prep type:", prep_type) + logger.info(f"Unsupported prep type: {prep_type}") return None @@ -174,13 +177,26 @@ def call_unet_and_control_net( cnet_idx = step % cnet_cnt cnet_info = control_nets[cnet_idx] - # print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio) + # logger.info(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio) if cnet_info.ratio < current_ratio: return original_unet(sample, timestep, encoder_hidden_states) guided_hint = guided_hints[cnet_idx] + + # gradual latent support: match the size of guided_hint to the size of sample + if guided_hint.shape[-2:] != sample.shape[-2:]: + # print(f"guided_hint.shape={guided_hint.shape}, sample.shape={sample.shape}") + org_dtype = guided_hint.dtype + if org_dtype == torch.bfloat16: + guided_hint = guided_hint.to(torch.float32) + guided_hint = torch.nn.functional.interpolate(guided_hint, size=sample.shape[-2:], mode="bicubic") + if org_dtype == torch.bfloat16: + guided_hint = guided_hint.to(org_dtype) + guided_hint = guided_hint.repeat((num_latent_input, 1, 1, 1)) - outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states_for_control_net) + outs = unet_forward( + True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states_for_control_net + ) outs = [o * cnet_info.weight for o in outs] # U-Net @@ -192,7 +208,7 @@ def call_unet_and_control_net( # ControlNet cnet_outs_list = [] for i, cnet_info in enumerate(control_nets): - # print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio) + # logger.info(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio) if cnet_info.ratio < current_ratio: continue guided_hint = guided_hints[i] @@ -232,7 +248,7 @@ def unet_forward( upsample_size = None if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): - print("Forward upsample size to force interpolation output size.") + logger.info("Forward upsample size to force interpolation output size.") forward_upsample_size = True # 1. time diff --git a/tools/resize_images_to_resolution.py b/tools/resize_images_to_resolution.py index 2d3224c4e..b8069fc1d 100644 --- a/tools/resize_images_to_resolution.py +++ b/tools/resize_images_to_resolution.py @@ -6,7 +6,10 @@ import math from PIL import Image import numpy as np - +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2, interpolation=None, save_as_png=False, copy_associated_files=False): # Split the max_resolution string by "," and strip any whitespaces @@ -83,7 +86,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi image.save(os.path.join(dst_img_folder, new_filename), quality=100) proc = "Resized" if current_pixels > max_pixels else "Saved" - print(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}") + logger.info(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}") # If other files with same basename, copy them with resolution suffix if copy_associated_files: @@ -94,7 +97,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi continue for max_resolution in max_resolutions: new_asoc_file = base + '+' + max_resolution + ext - print(f"Copy {asoc_file} as {new_asoc_file}") + logger.info(f"Copy {asoc_file} as {new_asoc_file}") shutil.copy(os.path.join(src_img_folder, asoc_file), os.path.join(dst_img_folder, new_asoc_file)) diff --git a/tools/show_metadata.py b/tools/show_metadata.py index 92ca7b1c8..05bfbe0a4 100644 --- a/tools/show_metadata.py +++ b/tools/show_metadata.py @@ -1,6 +1,10 @@ import json import argparse from safetensors import safe_open +from library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) parser = argparse.ArgumentParser() parser.add_argument("--model", type=str, required=True) @@ -10,10 +14,10 @@ metadata = f.metadata() if metadata is None: - print("No metadata found") + logger.error("No metadata found") else: # metadata is json dict, but not pretty printed # sort by key and pretty print print(json.dumps(metadata, indent=4, sort_keys=True)) - \ No newline at end of file + diff --git a/train_controlnet.py b/train_controlnet.py index cc0eaab7a..e44f08853 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -1,5 +1,4 @@ import argparse -import gc import json import math import os @@ -10,17 +9,11 @@ import toml from tqdm import tqdm -import torch - -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +import torch +from library.device_utils import init_ipex, clean_memory_on_device +init_ipex() - ipex_init() -except Exception: - pass from torch.nn.parallel import DistributedDataParallel as DDP from accelerate.utils import set_seed from diffusers import DDPMScheduler, ControlNetModel @@ -40,6 +33,12 @@ pyramid_noise_like, apply_noise_offset, ) +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) # TODO 他のスクリプトと共通化する @@ -61,6 +60,7 @@ def train(args): # training_started_at = time.time() train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) + setup_logging(args, reset=True) cache_latents = args.cache_latents use_user_config = args.dataset_config is not None @@ -74,11 +74,11 @@ def train(args): # データセットを準備する blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) if use_user_config: - print(f"Load dataset config from {args.dataset_config}") + logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "conditioning_data_dir"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) @@ -108,7 +108,7 @@ def train(args): train_util.debug_dataset(train_dataset_group) return if len(train_dataset_group) == 0: - print( + logger.error( "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" ) return @@ -119,7 +119,7 @@ def train(args): ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") accelerator = train_util.prepare_accelerator(args) is_main_process = accelerator.is_main_process @@ -224,10 +224,8 @@ def train(args): accelerator.is_main_process, ) vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() - + clean_memory_on_device(accelerator.device) + accelerator.wait_for_everyone() if args.gradient_checkpointing: @@ -241,8 +239,8 @@ def train(args): _, _, optimizer = train_util.get_optimizer(args, trainable_params) # dataloaderを準備する - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( train_dataset_group, @@ -258,7 +256,9 @@ def train(args): args.max_train_steps = args.max_train_epochs * math.ceil( len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) - accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) @@ -314,8 +314,10 @@ def train(args): accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}") - accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") - # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") @@ -337,7 +339,7 @@ def train(args): if accelerator.is_main_process: init_kwargs = {} if args.wandb_run_name: - init_kwargs['wandb'] = {'name': args.wandb_run_name} + init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( @@ -563,7 +565,7 @@ def remove_model(old_ckpt_name): accelerator.end_training() - if is_main_process and args.save_state: + if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) # del accelerator # この後メモリを使うのでこれは消す→printで使うので消さずにおく @@ -572,12 +574,13 @@ def remove_model(old_ckpt_name): ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) save_model(ckpt_name, controlnet, force_sync_upload=True) - print("model saved.") + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) diff --git a/train_db.py b/train_db.py index 14d9dff13..41a9a7b99 100644 --- a/train_db.py +++ b/train_db.py @@ -1,7 +1,6 @@ # DreamBooth training # XXX dropped option: fine_tune -import gc import argparse import itertools import math @@ -10,17 +9,11 @@ import toml from tqdm import tqdm -import torch - -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +import torch +from library.device_utils import init_ipex, clean_memory_on_device +init_ipex() - ipex_init() -except Exception: - pass from accelerate.utils import set_seed from diffusers import DDPMScheduler @@ -40,6 +33,12 @@ scale_v_prediction_loss_like_noise_prediction, apply_debiased_estimation, ) +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) # perlin_noise, @@ -47,6 +46,7 @@ def train(args): train_util.verify_training_args(args) train_util.prepare_dataset_args(args, False) + setup_logging(args, reset=True) cache_latents = args.cache_latents @@ -59,11 +59,11 @@ def train(args): if args.dataset_class is None: blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, False, True)) if args.dataset_config is not None: - print(f"Load dataset config from {args.dataset_config}") + logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "reg_data_dir"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) @@ -98,13 +98,13 @@ def train(args): ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") if args.gradient_accumulation_steps > 1: - print( + logger.warning( f"gradient_accumulation_steps is {args.gradient_accumulation_steps}. accelerate does not support gradient_accumulation_steps when training multiple models (U-Net and Text Encoder), so something might be wrong" ) - print( + logger.warning( f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデル(U-NetおよびText Encoder)の学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です" ) @@ -143,9 +143,7 @@ def train(args): with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() @@ -182,8 +180,8 @@ def train(args): _, _, optimizer = train_util.get_optimizer(args, trainable_params) # dataloaderを準備する - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( train_dataset_group, batch_size=1, @@ -198,7 +196,9 @@ def train(args): args.max_train_steps = args.max_train_epochs * math.ceil( len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) - accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) @@ -269,7 +269,7 @@ def train(args): if accelerator.is_main_process: init_kwargs = {} if args.wandb_run_name: - init_kwargs['wandb'] = {'name': args.wandb_run_name} + init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) @@ -444,7 +444,7 @@ def train(args): accelerator.end_training() - if args.save_state and is_main_process: + if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) del accelerator # この後メモリを使うのでこれは消す @@ -454,12 +454,13 @@ def train(args): train_util.save_sd_model_on_train_end( args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae ) - print("model saved.") + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, True, False, True) train_util.add_training_arguments(parser, True) diff --git a/train_network.py b/train_network.py index a75299cda..9e573d9f6 100644 --- a/train_network.py +++ b/train_network.py @@ -1,6 +1,5 @@ import importlib import argparse -import gc import math import os import sys @@ -11,18 +10,13 @@ import toml from tqdm import tqdm -import torch -from torch.nn.parallel import DistributedDataParallel as DDP -try: - import intel_extension_for_pytorch as ipex +import torch +from library.device_utils import init_ipex, clean_memory_on_device +init_ipex() - if torch.xpu.is_available(): - from library.ipex import ipex_init +from torch.nn.parallel import DistributedDataParallel as DDP - ipex_init() -except Exception: - pass from accelerate.utils import set_seed from diffusers import DDPMScheduler from library import model_util @@ -46,6 +40,12 @@ add_v_prediction_like_loss, apply_debiased_estimation, ) +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) class NetworkTrainer: @@ -117,7 +117,7 @@ def cache_text_encoder_outputs_if_needed( self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype ): for t_enc in text_encoders: - t_enc.to(accelerator.device) + t_enc.to(accelerator.device, dtype=weight_dtype) def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): input_ids = batch["input_ids"].to(accelerator.device) @@ -141,6 +141,7 @@ def train(self, args): training_started_at = time.time() train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) + setup_logging(args, reset=True) cache_latents = args.cache_latents use_dreambooth_method = args.in_json is None @@ -158,18 +159,18 @@ def train(self, args): if args.dataset_class is None: blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) if use_user_config: - print(f"Loading dataset config from {args.dataset_config}") + logger.info(f"Loading dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "reg_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.warning( "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) ) else: if use_dreambooth_method: - print("Using DreamBooth method.") + logger.info("Using DreamBooth method.") user_config = { "datasets": [ { @@ -180,7 +181,7 @@ def train(self, args): ] } else: - print("Training with captions.") + logger.info("Training with captions.") user_config = { "datasets": [ { @@ -209,7 +210,7 @@ def train(self, args): train_util.debug_dataset(train_dataset_group) return if len(train_dataset_group) == 0: - print( + logger.error( "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" ) return @@ -222,7 +223,7 @@ def train(self, args): self.assert_extra_args(args, train_dataset_group) # acceleratorを準備する - print("preparing accelerator") + logger.info("preparing accelerator") accelerator = train_util.prepare_accelerator(args) is_main_process = accelerator.is_main_process @@ -271,13 +272,12 @@ def train(self, args): with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() # 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される + # cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu self.cache_text_encoder_outputs_if_needed( args, accelerator, unet, vae, tokenizers, text_encoders, train_dataset_group, weight_dtype ) @@ -309,11 +309,12 @@ def train(self, args): ) if network is None: return + network_has_multiplier = hasattr(network, "set_multiplier") if hasattr(network, "prepare_network"): network.prepare_network(args) if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"): - print( + logger.warning( "warning: scale_weight_norms is specified but the network does not support it / scale_weight_normsが指定されていますが、ネットワークが対応していません" ) args.scale_weight_norms = False @@ -348,8 +349,8 @@ def train(self, args): optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) # dataloaderを準備する - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( train_dataset_group, @@ -389,17 +390,33 @@ def train(self, args): accelerator.print("enable full bf16 training.") network.to(weight_dtype) + unet_weight_dtype = te_weight_dtype = weight_dtype + # Experimental Feature: Put base model into fp8 to save vram + if args.fp8_base: + assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。" + assert ( + args.mixed_precision != "no" + ), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。" + accelerator.print("enable fp8 training.") + unet_weight_dtype = torch.float8_e4m3fn + te_weight_dtype = torch.float8_e4m3fn + unet.requires_grad_(False) - unet.to(dtype=weight_dtype) + unet.to(dtype=unet_weight_dtype) for t_enc in text_encoders: t_enc.requires_grad_(False) - # acceleratorがなんかよろしくやってくれるらしい - # TODO めちゃくちゃ冗長なのでコードを整理する + # in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16 + if t_enc.device.type != "cpu": + t_enc.to(dtype=te_weight_dtype) + # nn.Embedding not support FP8 + t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) + + # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good if train_unet: unet = accelerator.prepare(unet) else: - unet.to(accelerator.device, dtype=weight_dtype) # move to device because unet is not prepared by accelerator + unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator if train_text_encoder: if len(text_encoders) > 1: text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders] @@ -407,8 +424,8 @@ def train(self, args): text_encoder = accelerator.prepare(text_encoder) text_encoders = [text_encoder] else: - for t_enc in text_encoders: - t_enc.to(accelerator.device, dtype=weight_dtype) + pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set + network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler) if args.gradient_checkpointing: @@ -421,9 +438,6 @@ def train(self, args): if train_text_encoder: t_enc.text_model.embeddings.requires_grad_(True) - # set top parameter requires_grad = True for gradient checkpointing works - if not train_text_encoder: # train U-Net only - unet.parameters().__next__().requires_grad_(True) else: unet.eval() for t_enc in text_encoders: @@ -550,6 +564,11 @@ def train(self, args): "random_crop": bool(subset.random_crop), "shuffle_caption": bool(subset.shuffle_caption), "keep_tokens": subset.keep_tokens, + "keep_tokens_separator": subset.keep_tokens_separator, + "secondary_separator": subset.secondary_separator, + "enable_wildcard": bool(subset.enable_wildcard), + "caption_prefix": subset.caption_prefix, + "caption_suffix": subset.caption_suffix, } image_dir_or_metadata_file = None @@ -685,7 +704,7 @@ def train(self, args): if accelerator.is_main_process: init_kwargs = {} if args.wandb_run_name: - init_kwargs['wandb'] = {'name': args.wandb_run_name} + init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( @@ -754,7 +773,17 @@ def remove_model(old_ckpt_name): accelerator.print("NaN found in latents, replacing with zeros") latents = torch.nan_to_num(latents, 0, out=latents) latents = latents * self.vae_scale_factor - b_size = latents.shape[0] + + # get multiplier for each sample + if network_has_multiplier: + multipliers = batch["network_multipliers"] + # if all multipliers are same, use single multiplier + if torch.all(multipliers == multipliers[0]): + multipliers = multipliers[0].item() + else: + raise NotImplementedError("multipliers for each sample is not supported yet") + # print(f"set multiplier: {multipliers}") + accelerator.unwrap_model(network).set_multiplier(multipliers) with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning @@ -778,10 +807,24 @@ def remove_model(old_ckpt_name): args, noise_scheduler, latents ) + # ensure the hidden state will require grad + if args.gradient_checkpointing: + for x in noisy_latents: + x.requires_grad_(True) + for t in text_encoder_conds: + t.requires_grad_(True) + # Predict the noise residual with accelerator.autocast(): noise_pred = self.call_unet( - args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype + args, + accelerator, + unet, + noisy_latents.requires_grad_(train_unet), + timesteps, + text_encoder_conds, + batch, + weight_dtype, ) if args.v_parameterization: @@ -808,10 +851,11 @@ def remove_model(old_ckpt_name): loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし accelerator.backward(loss) - self.all_reduce_network(accelerator, network) # sync DDP grad manually - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = accelerator.unwrap_model(network).get_trainable_params() - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + if accelerator.sync_gradients: + self.all_reduce_network(accelerator, network) # sync DDP grad manually + if args.max_grad_norm != 0.0: + params_to_clip = accelerator.unwrap_model(network).get_trainable_params() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() lr_scheduler.step() @@ -896,19 +940,20 @@ def remove_model(old_ckpt_name): accelerator.end_training() - if is_main_process and args.save_state: + if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) if is_main_process: ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True) - print("model saved.") + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, True, True, True) train_util.add_training_arguments(parser, True) @@ -916,7 +961,9 @@ def setup_parser() -> argparse.ArgumentParser: config_util.add_config_arguments(parser) custom_train_functions.add_custom_train_arguments(parser) - parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない") + parser.add_argument( + "--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない" + ) parser.add_argument( "--save_model_as", type=str, @@ -928,10 +975,17 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率") parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率") - parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み") - parser.add_argument("--network_module", type=str, default=None, help="network module to train / 学習対象のネットワークのモジュール") parser.add_argument( - "--network_dim", type=int, default=None, help="network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)" + "--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み" + ) + parser.add_argument( + "--network_module", type=str, default=None, help="network module to train / 学習対象のネットワークのモジュール" + ) + parser.add_argument( + "--network_dim", + type=int, + default=None, + help="network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)", ) parser.add_argument( "--network_alpha", @@ -946,14 +1000,25 @@ def setup_parser() -> argparse.ArgumentParser: help="Drops neurons out of training every step (0 or None is default behavior (no dropout), 1 would drop all neurons) / 訓練時に毎ステップでニューロンをdropする(0またはNoneはdropoutなし、1は全ニューロンをdropout)", ) parser.add_argument( - "--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数" + "--network_args", + type=str, + default=None, + nargs="*", + help="additional arguments for network (key=value) / ネットワークへの追加の引数", + ) + parser.add_argument( + "--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する" ) - parser.add_argument("--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する") parser.add_argument( - "--network_train_text_encoder_only", action="store_true", help="only training Text Encoder part / Text Encoder関連部分のみ学習する" + "--network_train_text_encoder_only", + action="store_true", + help="only training Text Encoder part / Text Encoder関連部分のみ学習する", ) parser.add_argument( - "--training_comment", type=str, default=None, help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列" + "--training_comment", + type=str, + default=None, + help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列", ) parser.add_argument( "--dim_from_weights", diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 0e3912b1d..0266bc143 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -1,22 +1,15 @@ import argparse -import gc import math import os from multiprocessing import Value import toml from tqdm import tqdm -import torch - -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +import torch +from library.device_utils import init_ipex, clean_memory_on_device +init_ipex() - ipex_init() -except Exception: - pass from accelerate.utils import set_seed from diffusers import DDPMScheduler from transformers import CLIPTokenizer @@ -37,6 +30,12 @@ add_v_prediction_like_loss, apply_debiased_estimation, ) +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) imagenet_templates_small = [ "a photo of a {}", @@ -173,6 +172,7 @@ def train(self, args): train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) + setup_logging(args, reset=True) cache_latents = args.cache_latents @@ -183,7 +183,7 @@ def train(self, args): tokenizers = tokenizer_or_list if isinstance(tokenizer_or_list, list) else [tokenizer_or_list] # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") accelerator = train_util.prepare_accelerator(args) # mixed precisionに対応した型を用意しておき適宜castする @@ -293,7 +293,7 @@ def train(self, args): ] } else: - print("Train with captions.") + logger.info("Train with captions.") user_config = { "datasets": [ { @@ -368,9 +368,7 @@ def train(self, args): with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() @@ -387,8 +385,8 @@ def train(self, args): _, _, optimizer = train_util.get_optimizer(args, trainable_params) # dataloaderを準備する - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( train_dataset_group, batch_size=1, @@ -505,7 +503,7 @@ def train(self, args): if accelerator.is_main_process: init_kwargs = {} if args.wandb_run_name: - init_kwargs['wandb'] = {'name': args.wandb_run_name} + init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( @@ -730,24 +728,24 @@ def remove_model(old_ckpt_name): is_main_process = accelerator.is_main_process if is_main_process: text_encoder = accelerator.unwrap_model(text_encoder) + updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone() accelerator.end_training() - if args.save_state and is_main_process: + if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) - updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone() - if is_main_process: ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) save_model(ckpt_name, updated_embs_list, global_step, num_train_epochs, force_sync_upload=True) - print("model saved.") + logger.info("model saved.") def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, True, True, False) train_util.add_training_arguments(parser, True) @@ -763,7 +761,9 @@ def setup_parser() -> argparse.ArgumentParser: help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)", ) - parser.add_argument("--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み") + parser.add_argument( + "--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み" + ) parser.add_argument( "--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トークンに割り当てるembeddingsの要素数" ) @@ -773,7 +773,9 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること", ) - parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可") + parser.add_argument( + "--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可" + ) parser.add_argument( "--use_object_template", action="store_true", diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 71b43549d..ad7c267eb 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -1,20 +1,16 @@ import importlib import argparse -import gc import math import os import toml from multiprocessing import Value from tqdm import tqdm + import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass +from library.device_utils import init_ipex, clean_memory_on_device +init_ipex() + from accelerate.utils import set_seed import diffusers from diffusers import DDPMScheduler @@ -38,6 +34,12 @@ ) import library.original_unet as original_unet from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) imagenet_templates_small = [ "a photo of a {}", @@ -96,12 +98,13 @@ def train(args): if args.output_name is None: args.output_name = args.token_string use_template = args.use_object_template or args.use_style_template + setup_logging(args, reset=True) train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) if args.sample_every_n_steps is not None or args.sample_every_n_epochs is not None: - print( + logger.warning( "sample_every_n_steps and sample_every_n_epochs are not supported in this script currently / sample_every_n_stepsとsample_every_n_epochsは現在このスクリプトではサポートされていません" ) assert ( @@ -116,7 +119,7 @@ def train(args): tokenizer = train_util.load_tokenizer(args) # acceleratorを準備する - print("prepare accelerator") + logger.info("prepare accelerator") accelerator = train_util.prepare_accelerator(args) # mixed precisionに対応した型を用意しておき適宜castする @@ -129,7 +132,7 @@ def train(args): if args.init_word is not None: init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False) if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token: - print( + logger.warning( f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}" ) else: @@ -143,7 +146,7 @@ def train(args): ), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}" token_ids = tokenizer.convert_tokens_to_ids(token_strings) - print(f"tokens are added: {token_ids}") + logger.info(f"tokens are added: {token_ids}") assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered" assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}" @@ -171,7 +174,7 @@ def train(args): tokenizer.add_tokens(token_strings_XTI) token_ids_XTI = tokenizer.convert_tokens_to_ids(token_strings_XTI) - print(f"tokens are added (XTI): {token_ids_XTI}") + logger.info(f"tokens are added (XTI): {token_ids_XTI}") # Resize the token embeddings as we are adding new special tokens to the tokenizer text_encoder.resize_token_embeddings(len(tokenizer)) @@ -180,7 +183,7 @@ def train(args): if init_token_ids is not None: for i, token_id in enumerate(token_ids_XTI): token_embeds[token_id] = token_embeds[init_token_ids[(i // 16) % len(init_token_ids)]] - # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) + # logger.info(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) # load weights if args.weights is not None: @@ -188,22 +191,22 @@ def train(args): assert len(token_ids) == len( embeddings ), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}" - # print(token_ids, embeddings.size()) + # logger.info(token_ids, embeddings.size()) for token_id, embedding in zip(token_ids_XTI, embeddings): token_embeds[token_id] = embedding - # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) - print(f"weighs loaded") + # logger.info(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) + logger.info(f"weighs loaded") - print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}") + logger.info(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}") # データセットを準備する blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, False)) if args.dataset_config is not None: - print(f"Load dataset config from {args.dataset_config}") + logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "reg_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): - print( + logger.info( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) @@ -211,14 +214,14 @@ def train(args): else: use_dreambooth_method = args.in_json is None if use_dreambooth_method: - print("Use DreamBooth method.") + logger.info("Use DreamBooth method.") user_config = { "datasets": [ {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)} ] } else: - print("Train with captions.") + logger.info("Train with captions.") user_config = { "datasets": [ { @@ -242,7 +245,7 @@ def train(args): # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装 if use_template: - print(f"use template for training captions. is object: {args.use_object_template}") + logger.info(f"use template for training captions. is object: {args.use_object_template}") templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small replace_to = " ".join(token_strings) captions = [] @@ -266,7 +269,7 @@ def train(args): train_util.debug_dataset(train_dataset_group, show_input_ids=True) return if len(train_dataset_group) == 0: - print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください") + logger.error("No data found. Please verify arguments / 画像がありません。引数指定を確認してください") return if cache_latents: @@ -288,9 +291,7 @@ def train(args): with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() + clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() @@ -299,13 +300,13 @@ def train(args): text_encoder.gradient_checkpointing_enable() # 学習に必要なクラスを準備する - print("prepare optimizer, data loader etc.") + logger.info("prepare optimizer, data loader etc.") trainable_params = text_encoder.get_input_embeddings().parameters() _, _, optimizer = train_util.get_optimizer(args, trainable_params) # dataloaderを準備する - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( train_dataset_group, batch_size=1, @@ -320,7 +321,9 @@ def train(args): args.max_train_steps = args.max_train_epochs * math.ceil( len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) - print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + logger.info( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) @@ -334,7 +337,7 @@ def train(args): ) index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0] - # print(len(index_no_updates), torch.sum(index_no_updates)) + # logger.info(len(index_no_updates), torch.sum(index_no_updates)) orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone() # Freeze all parameters except for the token embeddings in text encoder @@ -372,15 +375,17 @@ def train(args): # 学習する total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - print("running training / 学習開始") - print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") - print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") - print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") - print(f" num epochs / epoch数: {num_train_epochs}") - print(f" batch size per device / バッチサイズ: {args.train_batch_size}") - print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") - print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") - print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + logger.info("running training / 学習開始") + logger.info(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + logger.info(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + logger.info(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + logger.info(f" num epochs / epoch数: {num_train_epochs}") + logger.info(f" batch size per device / バッチサイズ: {args.train_batch_size}") + logger.info( + f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" + ) + logger.info(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + logger.info(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") global_step = 0 @@ -395,17 +400,20 @@ def train(args): if accelerator.is_main_process: init_kwargs = {} if args.wandb_run_name: - init_kwargs['wandb'] = {'name': args.wandb_run_name} + init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) + accelerator.init_trackers( + "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs + ) # function for saving/removing def save_model(ckpt_name, embs, steps, epoch_no, force_sync_upload=False): os.makedirs(args.output_dir, exist_ok=True) ckpt_file = os.path.join(args.output_dir, ckpt_name) - print(f"\nsaving checkpoint: {ckpt_file}") + logger.info("") + logger.info(f"saving checkpoint: {ckpt_file}") save_weights(ckpt_file, embs, save_dtype) if args.huggingface_repo_id is not None: huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) @@ -413,12 +421,13 @@ def save_model(ckpt_name, embs, steps, epoch_no, force_sync_upload=False): def remove_model(old_ckpt_name): old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) if os.path.exists(old_ckpt_file): - print(f"removing old checkpoint: {old_ckpt_file}") + logger.info(f"removing old checkpoint: {old_ckpt_file}") os.remove(old_ckpt_file) # training loop for epoch in range(num_train_epochs): - print(f"\nepoch {epoch+1}/{num_train_epochs}") + logger.info("") + logger.info(f"epoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 text_encoder.train() @@ -577,7 +586,7 @@ def remove_model(old_ckpt_name): accelerator.end_training() - if args.save_state and is_main_process: + if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) updated_embs = text_encoder.get_input_embeddings().weight[token_ids_XTI].data.detach().clone() @@ -588,7 +597,7 @@ def remove_model(old_ckpt_name): ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) save_model(ckpt_name, updated_embs, global_step, num_train_epochs, force_sync_upload=True) - print("model saved.") + logger.info("model saved.") def save_weights(file, updated_embs, save_dtype): @@ -649,6 +658,7 @@ def load_weights(file): def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() + add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, True, True, False) train_util.add_training_arguments(parser, True) @@ -664,7 +674,9 @@ def setup_parser() -> argparse.ArgumentParser: help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)", ) - parser.add_argument("--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み") + parser.add_argument( + "--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み" + ) parser.add_argument( "--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トークンに割り当てるembeddingsの要素数" ) @@ -674,7 +686,9 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること", ) - parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可") + parser.add_argument( + "--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可" + ) parser.add_argument( "--use_object_template", action="store_true",