add detailed dataset config feature via an extra config file (kohya-ss#227)
* add config file schema

* change config file specification

* refactor config utility

* unify batch_size to train_batch_size

* fix indent size

* use batch_size instead of train_batch_size

* make cache_latents configurable per subset

* rename options
* bucket_repo_range
* shuffle_keep_tokens

* update readme

* revert to min_bucket_reso & max_bucket_reso

* use subset structure in dataset

* format import lines

* split mode specific options

* use only valid subset

* change valid subsets name

* manage multiple datasets by dataset group

* update config file sanitizer

* prune redundant validation

* add comments

* update type annotation

* rename json_file_name to metadata_file

* ignore when image dir is invalid

* fix tag shuffle and dropout

* ignore duplicated subset

* add method to check latent cachability

* fix format

* fix bug

* update caption dropout default values

* update annotation

* fix bug

* add option to enable bucket shuffle across datasets

* update blueprint generate function

* use blueprint generator for dataset initialization

* delete duplicated function

* update config readme

* delete debug print

* print dataset and subset info as info

* enable bucket_shuffle_across_dataset option

* update config readme for clarification

* add missing quotes to string option examples

* fix bug of bad usage of join

* preserve backward compatibility of trained model metadata

* enable shuffle in data loader by default

* delete resolved TODO

* add comment for image data handling

* fix reference bug

* fix undefined variable bug

* prevent raise overwriting

* assert image_dir and metadata_file validity

* add debug message for ignoring subset

* fix inconsistent import statement

* loosen too strict validation on float value

* sanitize argument parser separately

* make image_dir optional for fine tuning dataset

* fix import

* fix trailing characters in print

* parse flexible dataset config deterministically

* use relative import

* print supplementary message for parsing error

* add note about different methods

* add note of benefit of separate dataset

* add error example

* add note for english readme plan

---------

Co-authored-by: Kohya S <[email protected]>
fur0ut0 and kohya-ss authored Mar 1, 2023
1 parent 8270765 commit 8abb864
Showing 8 changed files with 1,369 additions and 320 deletions.
279 changes: 279 additions & 0 deletions config_README-ja.md

Large diffs are not rendered by default.
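The new config_README-ja.md documents the config file format in Japanese; its contents are not rendered here (an English version is planned, per the commit notes above). As a purely hypothetical sketch, based on the fallback dictionary in the fine_tune.py diff below and on option names mentioned in the commit messages rather than on the README itself, a user config describing two datasets might look like this in Python form:

# Hypothetical illustration only: the kind of structure config_util.load_user_config()
# is expected to return. Paths and values are invented; only "datasets", "subsets",
# "image_dir" and "metadata_file" appear verbatim in this diff.
user_config = {
  "datasets": [
    {
      "batch_size": 4,  # option name from the commit messages; dataset-level placement assumed
      "subsets": [
        {
          "image_dir": "train_images_a",       # per-subset image directory
          "metadata_file": "meta_lat_a.json",  # renamed from json_file_name in this PR
          "keep_tokens": 1,                    # assumed per-subset option
        },
      ],
    },
    {
      # a second dataset: the new dataset group manages several datasets at once
      "batch_size": 2,
      "subsets": [
        {
          "image_dir": "train_images_b",
          "metadata_file": "meta_lat_b.json",
        },
      ],
    },
  ],
}

The on-disk file is whatever format config_util.load_user_config() accepts (see config_README-ja.md for the authoritative schema); the dictionary above only shows the datasets → subsets nesting that the rest of this diff relies on.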

52 changes: 34 additions & 18 deletions fine_tune.py
@@ -13,7 +13,11 @@
from diffusers import DDPMScheduler

import library.train_util as train_util

import library.config_util as config_util
from library.config_util import (
  ConfigSanitizer,
  BlueprintGenerator,
)

def collate_fn(examples):
  return examples[0]
@@ -30,25 +34,36 @@ def train(args):

tokenizer = train_util.load_tokenizer(args)

train_dataset = train_util.FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
                                             tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
                                             args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
                                             args.bucket_reso_steps, args.bucket_no_upscale,
                                             args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
                                             args.dataset_repeats, args.debug_dataset)

# set the caption dropout rates for the training data
train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)

train_dataset.make_buckets()
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True))
if args.config_file is not None:
  print(f"Load config file from {args.config_file}")
  user_config = config_util.load_user_config(args.config_file)
  ignored = ["train_data_dir", "in_json"]
  if any(getattr(args, attr) is not None for attr in ignored):
    print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
else:
  user_config = {
    "datasets": [{
      "subsets": [{
        "image_dir": args.train_data_dir,
        "metadata_file": args.in_json,
      }]
    }]
  }

blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)

if args.debug_dataset:
train_util.debug_dataset(train_dataset)
train_util.debug_dataset(train_dataset_group)
return
if len(train_dataset) == 0:
if len(train_dataset_group) == 0:
print("No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。")
return

if cache_latents:
  assert train_dataset_group.is_latent_cachable(), "when caching latents, color_aug and random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"

# prepare accelerator
print("prepare accelerator")
accelerator, unwrap_model = train_util.prepare_accelerator(args)
@@ -109,7 +124,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset.cache_latents(vae)
train_dataset_group.cache_latents(vae)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
@@ -155,7 +170,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
# number of DataLoader worker processes: 0 means the main process
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count - 1, capped at the specified maximum
train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)

# calculate the number of training steps
if args.max_train_epochs is not None:
@@ -199,7 +214,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
# start training
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
print("running training / 学習開始")
print(f" num examples / サンプル数: {train_dataset.num_train_images}")
print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
print(f" num epochs / epoch数: {num_train_epochs}")
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
@@ -218,7 +233,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset.set_current_epoch(epoch + 1)
train_dataset_group.set_current_epoch(epoch + 1)

for m in training_models:
m.train()
@@ -340,6 +355,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
train_util.add_training_arguments(parser, False)
train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)

parser.add_argument("--diffusers_xformers", action='store_true',
help='use xformers by diffusers / Diffusersでxformersを使用する')
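
Taken together, the fine_tune.py hunks above replace the direct FineTuningDataset construction with a sanitize / blueprint / dataset-group pipeline. The following is a condensed, hypothetical sketch of that flow, using only the calls visible in this diff; the meaning of the three ConfigSanitizer booleans and the exact signatures are assumptions read off the surrounding lines, not taken from library/config_util.py.

# Sketch only: mirrors the calls shown in the fine_tune.py diff above.
import argparse

import library.config_util as config_util
from library.config_util import (
  ConfigSanitizer,
  BlueprintGenerator,
)

def build_dataset_group(args: argparse.Namespace, tokenizer):
  # The three booleans mirror ConfigSanitizer(False, True, True) above;
  # which dataset/option kinds they enable is assumed, not verified.
  blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True))

  if args.config_file is not None:
    # A user-supplied config file takes precedence over train_data_dir / in_json.
    user_config = config_util.load_user_config(args.config_file)
  else:
    # Otherwise synthesize a one-dataset, one-subset config from the legacy options.
    user_config = {
      "datasets": [{
        "subsets": [{
          "image_dir": args.train_data_dir,
          "metadata_file": args.in_json,
        }]
      }]
    }

  blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
  return config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)

The returned dataset group then stands in for the old single dataset everywhere downstream (latent caching, the DataLoader, which is now created with shuffle=True, num_train_images, set_current_epoch), which is what the remaining hunks adjust.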
