Commit e0c597e

Merge branch 'pr/1250' into master3

gesen2egee committed Apr 9, 2024
2 parents aa80cc8 + 6116b62
Showing 16 changed files with 440 additions and 284 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/typos.yaml
@@ -18,4 +18,4 @@ jobs:
- uses: actions/checkout@v4

- name: typos-action
uses: crate-ci/typos@v1.17.2
uses: crate-ci/typos@v1.19.0
44 changes: 44 additions & 0 deletions README-ja.md
@@ -111,3 +111,47 @@ Conv2d 3x3への拡大は [cloneofsimo氏](https://github.com/cloneofsimo/lora)

[BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause

## Additional Information

### About the LoRA names

To avoid confusion, the LoRA variants supported by `train_network.py` have been given names. The documentation has been updated accordingly. The names below are specific to this repository.

1. __LoRA-LierLa__ : (LoRA for __Li__ n __e__ a __r__ __La__ yers, read as "LierLa")

    LoRA applied to Linear layers and to Conv2d layers with a 1x1 kernel

2. __LoRA-C3Lier__ : (LoRA for __C__ onvolutional layers with __3__ x3 Kernel and __Li__ n __e__ a __r__ layers, read as "C3Lier")

    In addition to 1., LoRA applied to Conv2d layers with a 3x3 kernel

LoRA-LierLa is used by default. To use LoRA-C3Lier, specify `conv_dim` in `--network_args`.
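For example (the exact flags depend on your training command, and the values here are only illustrative), adding `--network_args "conv_dim=4" "conv_alpha=1"` to the `train_network.py` invocation enables the 3x3 Conv2d modules of LoRA-C3Lier.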

<!--
LoRA-LierLa can be used with the [extension for the Web UI](https://github.com/kohya-ss/sd-webui-additional-networks), or with the built-in LoRA feature of AUTOMATIC1111's Web UI.
To generate images in the Web UI with LoRA-C3Lier, use the extension.
-->

### Generating sample images during training

A prompt file might look like this:

```
# prompt 1
masterpiece, best quality, (1girl), in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28
# prompt 2
masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n (low quality, worst quality), bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
```

Lines starting with `#` are comments. Options can be specified after the prompt in the form "two hyphens plus a lowercase letter", such as `--n`. The following options are available:

* `--n` Negative prompt up to the next option.
* `--w` Specifies the width of the generated image.
* `--h` Specifies the height of the generated image.
* `--d` Specifies the seed of the generated image.
* `--l` Specifies the CFG scale of the generated image.
* `--s` Specifies the number of steps in the generation.

Prompt weighting such as `( )` and `[ ]` also works.
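As a rough illustration of this option format, the sketch below parses a single prompt line into a dictionary. It only illustrates the syntax and is not the parser actually used by the training scripts; the key names and default values are made up for the example.

```python
# Illustrative sketch: split one sample-prompt line into the prompt and its options.
import re

def parse_prompt_line(line: str) -> dict:
    result = {"prompt": "", "negative_prompt": "", "width": 512, "height": 512,
              "seed": None, "scale": 7.5, "steps": 30}
    keys = {"n": "negative_prompt", "w": "width", "h": "height",
            "d": "seed", "l": "scale", "s": "steps"}
    # Everything before the first "--x" token is the positive prompt;
    # the rest alternates between option letters and their values.
    tokens = re.split(r"\s+--([a-z])\s+", line.strip())
    result["prompt"] = tokens[0].strip()
    for letter, value in zip(tokens[1::2], tokens[2::2]):
        key = keys.get(letter)
        if key is None:
            continue  # ignore unknown options in this sketch
        if key in ("width", "height", "seed", "steps"):
            result[key] = int(value.strip())
        elif key == "scale":
            result[key] = float(value.strip())
        else:
            result[key] = value.strip()
    return result

print(parse_prompt_line(
    "masterpiece, best quality, 1girl --n low quality, worst quality --w 768 --h 768 --d 1 --l 7.5 --s 28"
))
```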
323 changes: 107 additions & 216 deletions README.md

Large diffs are not rendered by default.

33 changes: 25 additions & 8 deletions fine_tune.py
@@ -268,18 +268,31 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
else:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
if args.optimizer_type.lower().endswith("schedulefree"):
ds_model, optimizer, train_dataloader = accelerator.prepare(
ds_model, optimizer, train_dataloader
)
else:
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
training_models = [ds_model]
else:
# acceleratorがなんかよろしくやってくれるらしい
if args.train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
if args.optimizer_type.lower().endswith("schedulefree"):
unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader
)
else:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
if args.optimizer_type.lower().endswith("schedulefree"):
unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)

# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
@@ -335,6 +348,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

for m in training_models:
m.train()
if args.optimizer_type.lower().endswith("schedulefree"):
optimizer.train()

for step, batch in enumerate(train_dataloader):
current_step.value = global_step
@@ -405,7 +420,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

optimizer.step()
lr_scheduler.step()
if not args.optimizer_type.lower().endswith("schedulefree"):
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)


@@ -536,6 +552,7 @@ def setup_parser() -> argparse.ArgumentParser:
parser = setup_parser()

args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)

train(args)
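Taken together, the fine_tune.py changes above follow one pattern: when the optimizer type ends in "schedulefree", no learning-rate scheduler is passed to accelerator.prepare, `optimizer.train()` is called before the training steps, and `lr_scheduler.step()` is skipped. A simplified, self-contained sketch of that pattern is shown below; it assumes the `schedulefree` package, omits the accelerator and the real models, and uses illustrative names and hyperparameters.

```python
# Simplified sketch of the schedule-free handling applied in the diff above.
import torch
import schedulefree

model = torch.nn.Linear(8, 1)
uses_schedulefree = True  # in the scripts: args.optimizer_type.lower().endswith("schedulefree")

if uses_schedulefree:
    # Schedule-free optimizers manage the learning-rate behavior internally,
    # so no external lr_scheduler is created, prepared, or stepped.
    optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-3)
    lr_scheduler = None
else:
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)

model.train()
if uses_schedulefree:
    optimizer.train()  # the optimizer itself has train/eval modes

for _ in range(10):
    loss = model(torch.randn(4, 8)).pow(2).mean()
    loss.backward()
    optimizer.step()
    if not uses_schedulefree:
        lr_scheduler.step()  # only ordinary optimizers step a scheduler
    optimizer.zero_grad(set_to_none=True)

if uses_schedulefree:
    optimizer.eval()  # switch back before evaluation or before saving weights
```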
128 changes: 116 additions & 12 deletions library/train_util.py
@@ -2001,7 +2001,7 @@ def __init__(
subset.image_dir,
False,
None,
subset.caption_extension,
subset.caption_extension,
subset.cache_info,
subset.num_repeats,
subset.shuffle_caption,
@@ -3235,7 +3235,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
)
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
parser.add_argument(
"--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / grandient checkpointingを有効にする"
"--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / gradient checkpointingを有効にする"
)
parser.add_argument(
"--gradient_accumulation_steps",
@@ -3403,20 +3403,21 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
type=str,
default="l2",
choices=["l2", "huber", "smooth_l1"],
help="The type of loss to use and whether it's scheduled based on the timestep"
help="The type of loss function to use (L2, Huber, or smooth L1), default is L2 / 使用する損失関数の種類(L2、Huber、またはsmooth L1)、デフォルトはL2",
)
parser.add_argument(
"--huber_schedule",
type=str,
default="exponential",
default="snr",
choices=["constant", "exponential", "snr"],
help="The type of loss to use and whether it's scheduled based on the timestep"
help="The scheduling method for Huber loss (constant, exponential, or SNR-based). Only used when loss_type is 'huber' or 'smooth_l1'. default is snr"
+ " / Huber損失のスケジューリング方法(constant、exponential、またはSNRベース)。loss_typeが'huber'または'smooth_l1'の場合に有効、デフォルトは snr",
)
parser.add_argument(
"--huber_c",
type=float,
default=0.1,
help="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type.",
help="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1",
)
parser.add_argument(
"--lowram",
@@ -3615,6 +3616,60 @@ def add_masked_loss_arguments(parser: argparse.ArgumentParser):
)


# verify command line args for training
def verify_command_line_training_args(args: argparse.Namespace):
# if wandb is enabled, the command line is exposed to the public
# check whether sensitive options are included in the command line arguments
# if so, warn or inform the user to move them to the configuration file
# wandbが有効な場合、コマンドラインが公開される
# 学習用のコマンドライン引数に敏感なオプションが含まれているかどうかを確認し、
# 含まれている場合は設定ファイルに移動するようにユーザーに警告または通知する

wandb_enabled = args.log_with is not None and args.log_with != "tensorboard" # "all" or "wandb"
if not wandb_enabled:
return

sensitive_args = ["wandb_api_key", "huggingface_token"]
sensitive_path_args = [
"pretrained_model_name_or_path",
"vae",
"tokenizer_cache_dir",
"train_data_dir",
"conditioning_data_dir",
"reg_data_dir",
"output_dir",
"logging_dir",
]

for arg in sensitive_args:
if getattr(args, arg, None) is not None:
logger.warning(
f"wandb is enabled, but option `{arg}` is included in the command line. Because the command line is exposed to the public, it is recommended to move it to the `.toml` file."
+ f" / wandbが有効で、かつオプション `{arg}` がコマンドラインに含まれています。コマンドラインは公開されるため、`.toml`ファイルに移動することをお勧めします。"
)

# if path is absolute, it may include sensitive information
for arg in sensitive_path_args:
if getattr(args, arg, None) is not None and os.path.isabs(getattr(args, arg)):
logger.info(
f"wandb is enabled, but option `{arg}` is included in the command line and it is an absolute path. Because the command line is exposed to the public, it is recommended to move it to the `.toml` file or use relative path."
+ f" / wandbが有効で、かつオプション `{arg}` がコマンドラインに含まれており、絶対パスです。コマンドラインは公開されるため、`.toml`ファイルに移動するか、相対パスを使用することをお勧めします。"
)

if getattr(args, "config_file", None) is not None:
logger.info(
f"wandb is enabled, but option `config_file` is included in the command line. Because the command line is exposed to the public, please be careful about the information included in the path."
+ f" / wandbが有効で、かつオプション `config_file` がコマンドラインに含まれています。コマンドラインは公開されるため、パスに含まれる情報にご注意ください。"
)

# other sensitive options
if args.huggingface_repo_id is not None and args.huggingface_repo_visibility != "public":
logger.info(
f"wandb is enabled, but option huggingface_repo_id is included in the command line and huggingface_repo_visibility is not 'public'. Because the command line is exposed to the public, it is recommended to move it to the `.toml` file."
+ f" / wandbが有効で、かつオプション huggingface_repo_id がコマンドラインに含まれており、huggingface_repo_visibility が 'public' ではありません。コマンドラインは公開されるため、`.toml`ファイルに移動することをお勧めします。"
)


def verify_training_args(args: argparse.Namespace):
r"""
Verify training arguments. Also reflect highvram option to global variable
Expand Down Expand Up @@ -4283,6 +4338,21 @@ def get_optimizer(args, trainable_params):
logger.info(f"use AdamW optimizer | {optimizer_kwargs}")
optimizer_class = torch.optim.AdamW
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

elif optimizer_type.endswith("schedulefree".lower()):
try:
import schedulefree as sf
except ImportError:
raise ImportError("No schedulefree / schedulefreeがインストールされていないようです")
if optimizer_type == "AdamWScheduleFree".lower():
optimizer_class = sf.AdamWScheduleFree
logger.info(f"use AdamWScheduleFree optimizer | {optimizer_kwargs}")
elif optimizer_type == "SGDScheduleFree".lower():
optimizer_class = sf.SGDScheduleFree
logger.info(f"use SGDScheduleFree optimizer | {optimizer_kwargs}")
else:
raise ValueError(f"Unknown optimizer type: {optimizer_type}")
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

if optimizer is None:
# 任意のoptimizerを使う
Expand Down Expand Up @@ -5181,6 +5251,38 @@ def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler,

return timesteps, huber_c

def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device):

# TODO: if a Huber loss is selected, a single constant timestep is currently used for the whole batch.
# In the future there may be a smarter way.

if args.loss_type == "huber" or args.loss_type == "smooth_l1":
timesteps = torch.randint(min_timestep, max_timestep, (1,), device="cpu")
timestep = timesteps.item()

if args.huber_schedule == "exponential":
alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
huber_c = math.exp(-alpha * timestep)
elif args.huber_schedule == "snr":
alphas_cumprod = noise_scheduler.alphas_cumprod[timestep]
sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
elif args.huber_schedule == "constant":
huber_c = args.huber_c
else:
raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")

timesteps = timesteps.repeat(b_size).to(device)
elif args.loss_type == "l2":
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device)
huber_c = 1 # may be anything, as it's not used
else:
raise NotImplementedError(f"Unknown loss type {args.loss_type}")
timesteps = timesteps.long()

return timesteps, huber_c


def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
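For reference, the per-timestep Huber parameter computed in the added `get_timesteps_and_huber_c` above can be written compactly as follows (a restatement of the code, with c = `args.huber_c`, T = `num_train_timesteps`, and ᾱ_t = `alphas_cumprod[t]`):

```latex
\begin{aligned}
\text{exponential:}\quad & c_t = \exp\!\left(\frac{t \ln c}{T}\right) = c^{\,t/T},\\
\text{snr:}\quad & \sigma_t = \sqrt{\frac{1-\bar{\alpha}_t}{\bar{\alpha}_t}}, \qquad
  c_t = \frac{1-c}{(1+\sigma_t)^2} + c,\\
\text{constant:}\quad & c_t = c.
\end{aligned}
```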
@@ -5216,24 +5318,26 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
return noise, noisy_latents, timesteps, huber_c

# NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already
def conditional_loss(model_pred:torch.Tensor, target:torch.Tensor, reduction:str="mean", loss_type:str="l2", huber_c:float=0.1):

if loss_type == 'l2':
def conditional_loss(
model_pred: torch.Tensor, target: torch.Tensor, reduction: str = "mean", loss_type: str = "l2", huber_c: float = 0.1
):

if loss_type == "l2":
loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
elif loss_type == 'huber':
elif loss_type == "huber":
loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
if reduction == "mean":
loss = torch.mean(loss)
elif reduction == "sum":
loss = torch.sum(loss)
elif loss_type == 'smooth_l1':
elif loss_type == "smooth_l1":
loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
if reduction == "mean":
loss = torch.mean(loss)
elif reduction == "sum":
loss = torch.sum(loss)
else:
raise NotImplementedError(f'Unsupported Loss Type {loss_type}')
raise NotImplementedError(f"Unsupported Loss Type {loss_type}")
return loss

def append_lr_to_logs(logs, lr_scheduler, optimizer_type, including_unet=True):
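The Huber and smooth-L1 branches of `conditional_loss` above are pseudo-Huber losses. With residual r = model_pred − target and parameter c, they compute element-wise, before reduction:

```latex
\begin{aligned}
\text{l2:}\quad & r^2,\\
\text{huber:}\quad & 2c\left(\sqrt{r^2 + c^2} - c\right)
  \;\approx\; r^2 \ \text{for } |r| \ll c, \qquad \approx 2c\,|r| \ \text{for } |r| \gg c,\\
\text{smooth\_l1:}\quad & 2\left(\sqrt{r^2 + c^2} - c\right).
\end{aligned}
```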
5 changes: 2 additions & 3 deletions networks/lora.py
@@ -249,14 +249,13 @@ def get_mask_for_x(self, x):
area = x.size()[1]

mask = self.network.mask_dic.get(area, None)
if mask is None:
# raise ValueError(f"mask is None for resolution {area}")
if mask is None or len(x.size()) == 2:
# emb_layers in SDXL doesn't have mask
# if "emb" not in self.lora_name:
# print(f"mask is None for resolution {self.lora_name}, {area}, {x.size()}")
mask_size = (1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1)
return torch.ones(mask_size, dtype=x.dtype, device=x.device) / self.network.num_sub_prompts
if len(x.size()) != 4:
if len(x.size()) == 3:
mask = torch.reshape(mask, (1, -1, 1))
return mask


0 comments on commit e0c597e
