diff --git a/library/config_util.py b/library/config_util.py index e8e0fda7c..1bf7ed955 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -85,6 +85,8 @@ class BaseDatasetParams: max_token_length: int = None resolution: Optional[Tuple[int, int]] = None debug_dataset: bool = False + validation_seed: Optional[int] = None + validation_split: float = 0.0 @dataclass class DreamBoothDatasetParams(BaseDatasetParams): @@ -200,6 +202,8 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence] "enable_bucket": bool, "max_bucket_reso": int, "min_bucket_reso": int, + "validation_seed": int, + "validation_split": float, "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), } @@ -427,64 +431,89 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu dataset_klass = FineTuningDataset subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] - dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) + dataset = dataset_klass(subsets=subsets, is_train=True, **asdict(dataset_blueprint.params)) datasets.append(dataset) - # print info - info = "" - for i, dataset in enumerate(datasets): - is_dreambooth = isinstance(dataset, DreamBoothDataset) - is_controlnet = isinstance(dataset, ControlNetDataset) - info += dedent(f"""\ - [Dataset {i}] - batch_size: {dataset.batch_size} - resolution: {(dataset.width, dataset.height)} - enable_bucket: {dataset.enable_bucket} - """) - - if dataset.enable_bucket: - info += indent(dedent(f"""\ - min_bucket_reso: {dataset.min_bucket_reso} - max_bucket_reso: {dataset.max_bucket_reso} - bucket_reso_steps: {dataset.bucket_reso_steps} - bucket_no_upscale: {dataset.bucket_no_upscale} - \n"""), " ") + val_datasets:List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] + for dataset_blueprint in dataset_group_blueprint.datasets: + if 
dataset_blueprint.params.validation_split <= 0.0: + continue + if dataset_blueprint.is_controlnet: + subset_klass = ControlNetSubset + dataset_klass = ControlNetDataset + elif dataset_blueprint.is_dreambooth: + subset_klass = DreamBoothSubset + dataset_klass = DreamBoothDataset else: - info += "\n" - - for j, subset in enumerate(dataset.subsets): - info += indent(dedent(f"""\ - [Subset {j} of Dataset {i}] - image_dir: "{subset.image_dir}" - image_count: {subset.img_count} - num_repeats: {subset.num_repeats} - shuffle_caption: {subset.shuffle_caption} - keep_tokens: {subset.keep_tokens} - caption_dropout_rate: {subset.caption_dropout_rate} - caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} - caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} - caption_prefix: {subset.caption_prefix} - caption_suffix: {subset.caption_suffix} - color_aug: {subset.color_aug} - flip_aug: {subset.flip_aug} - face_crop_aug_range: {subset.face_crop_aug_range} - random_crop: {subset.random_crop} - token_warmup_min: {subset.token_warmup_min}, - token_warmup_step: {subset.token_warmup_step}, - """), " ") - - if is_dreambooth: + subset_klass = FineTuningSubset + dataset_klass = FineTuningDataset + + subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] + dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params)) + val_datasets.append(dataset) + + # print info + def print_info(_datasets): + info = "" + for i, dataset in enumerate(_datasets): + is_dreambooth = isinstance(dataset, DreamBoothDataset) + is_controlnet = isinstance(dataset, ControlNetDataset) + info += dedent(f"""\ + [Dataset {i}] + batch_size: {dataset.batch_size} + resolution: {(dataset.width, dataset.height)} + enable_bucket: {dataset.enable_bucket} + """) + + if dataset.enable_bucket: info += indent(dedent(f"""\ - is_reg: {subset.is_reg} - class_tokens: {subset.class_tokens} - caption_extension: 
{subset.caption_extension} - \n"""), " ") - elif not is_controlnet: + min_bucket_reso: {dataset.min_bucket_reso} + max_bucket_reso: {dataset.max_bucket_reso} + bucket_reso_steps: {dataset.bucket_reso_steps} + bucket_no_upscale: {dataset.bucket_no_upscale} + \n"""), " ") + else: + info += "\n" + + for j, subset in enumerate(dataset.subsets): info += indent(dedent(f"""\ - metadata_file: {subset.metadata_file} - \n"""), " ") - - print(info) + [Subset {j} of Dataset {i}] + image_dir: "{subset.image_dir}" + image_count: {subset.img_count} + num_repeats: {subset.num_repeats} + shuffle_caption: {subset.shuffle_caption} + keep_tokens: {subset.keep_tokens} + caption_dropout_rate: {subset.caption_dropout_rate} + caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} + caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} + caption_prefix: {subset.caption_prefix} + caption_suffix: {subset.caption_suffix} + color_aug: {subset.color_aug} + flip_aug: {subset.flip_aug} + face_crop_aug_range: {subset.face_crop_aug_range} + random_crop: {subset.random_crop} + token_warmup_min: {subset.token_warmup_min}, + token_warmup_step: {subset.token_warmup_step}, + """), " ") + + if is_dreambooth: + info += indent(dedent(f"""\ + is_reg: {subset.is_reg} + class_tokens: {subset.class_tokens} + caption_extension: {subset.caption_extension} + \n"""), " ") + elif not is_controlnet: + info += indent(dedent(f"""\ + metadata_file: {subset.metadata_file} + \n"""), " ") + + print(info) + + print_info(datasets) + + if len(val_datasets) > 0: + print("Validation dataset") + print_info(val_datasets) # make buckets first because it determines the length of dataset # and set the same seed for all datasets @@ -494,7 +523,15 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu dataset.make_buckets() dataset.set_seed(seed) - return DatasetGroup(datasets) + for i, dataset in enumerate(val_datasets): + print(f"[Validation Dataset {i}]") + 
dataset.make_buckets()
+        dataset.set_seed(seed)
+
+    return (
+        DatasetGroup(datasets),
+        DatasetGroup(val_datasets) if val_datasets else None
+    )
 
 
 def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None):
diff --git a/library/train_util.py b/library/train_util.py
index 40bf8474e..8e29aeda0 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -122,6 +122,22 @@
 TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
 
 
+def split_train_val(paths, is_train, validation_split, validation_seed):
+    if validation_seed is not None:
+        print(f"Using validation seed: {validation_seed}")
+        prevstate = random.getstate()
+        random.seed(validation_seed)
+        random.shuffle(paths)
+        random.setstate(prevstate)
+    else:
+        random.shuffle(paths)  # NOTE(review): unseeded shuffle differs between the train and val dataset builds, so their splits can overlap — set validation_seed
+
+    if is_train:
+        return paths[0:math.ceil(len(paths) * (1 - validation_split))]
+    else:
+        return paths[math.ceil(len(paths) * (1 - validation_split)):]  # exact complement of the train slice: no overlap/gap (round() could overlap train)
+
+
 class ImageInfo:
     def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
         self.image_key: str = image_key
@@ -1306,6 +1322,7 @@ class DreamBoothDataset(BaseDataset):
     def __init__(
         self,
         subsets: Sequence[DreamBoothSubset],
+        is_train: bool,
         batch_size: int,
         tokenizer,
         max_token_length,
@@ -1316,12 +1333,18 @@ def __init__(
         bucket_reso_steps: int,
         bucket_no_upscale: bool,
         prior_loss_weight: float,
+        validation_split: float,
+        validation_seed: Optional[int],
         debug_dataset,
     ) -> None:
         super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
 
         assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
 
+        self.is_train = is_train
+        self.validation_split = validation_split
+        self.validation_seed = validation_seed
+
         self.batch_size = batch_size
         self.size = min(self.width, self.height)  # 短いほう
         self.prior_loss_weight = prior_loss_weight
@@ -1374,6 +1397,9 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
                 return [], []
 
             img_paths = 
glob_images(subset.image_dir, "*") + + if self.validation_split > 0.0: + img_paths = split_train_val(img_paths, self.is_train, self.validation_split, self.validation_seed) print(f"found directory {subset.image_dir} contains {len(img_paths)} image files") # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う diff --git a/train_network.py b/train_network.py index a5362ddb6..bdf633d5f 100644 --- a/train_network.py +++ b/train_network.py @@ -185,10 +185,11 @@ def train(self, args): } blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: # use arbitrary dataset class train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) + val_dataset_group = None # placeholder until validation dataset supported for arbitrary current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -208,6 +209,10 @@ def train(self, args): assert ( train_dataset_group.is_latent_cacheable() ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + if val_dataset_group is not None: + assert ( + val_dataset_group.is_latent_cacheable() + ), "when caching validation latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" self.assert_extra_args(args, train_dataset_group) @@ -260,6 +265,9 @@ def train(self, args): vae.eval() with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + if val_dataset_group is not None: + print("Cache validation latents...") + val_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") if torch.cuda.is_available(): torch.cuda.empty_cache() @@ -341,61 +349,8 @@ def train(self, 
args): # DataLoaderのプロセス数:0はメインプロセスになる n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで - def get_indices_without_reg(dataset: torch.utils.data.Dataset): - return [id for id, (key, item) in enumerate(dataset.image_data.items()) if item.is_reg is False] - - from typing import Sequence, Union - from torch._utils import _accumulate - import warnings - from torch.utils.data.dataset import Subset - - def random_split(dataset: torch.utils.data.Dataset, lengths: Sequence[Union[int, float]]): - indices = get_indices_without_reg(dataset) - random.shuffle(indices) - - subset_lengths = [] - - for i, frac in enumerate(lengths): - if frac < 0 or frac > 1: - raise ValueError(f"Fraction at index {i} is not between 0 and 1") - n_items_in_split = int(math.floor(len(indices) * frac)) - subset_lengths.append(n_items_in_split) - - remainder = len(indices) - sum(subset_lengths) - - for i in range(remainder): - idx_to_add_at = i % len(subset_lengths) - subset_lengths[idx_to_add_at] += 1 - - lengths = subset_lengths - for i, length in enumerate(lengths): - if length == 0: - warnings.warn(f"Length of split at index {i} is 0. 
" - f"This might result in an empty dataset.") - - if sum(lengths) != len(indices): - raise ValueError("Sum of input lengths does not equal the length of the input dataset!") - - return [Subset(dataset, indices[offset - length: offset]) for offset, length in zip(_accumulate(lengths), lengths)] - - - if args.validation_ratio > 0.0: - train_ratio = 1 - args.validation_ratio - validation_ratio = args.validation_ratio - train, val = random_split( - train_dataset_group, - [train_ratio, validation_ratio] - ) - print(f"split dataset by ratio: train {train_ratio}, validation {validation_ratio}") - print(f"train images: {len(train)}, validation images: {len(val)}") - else: - train = train_dataset_group - val = [] - - - train_dataloader = torch.utils.data.DataLoader( - train, + train_dataset_group, batch_size=1, shuffle=True, collate_fn=collator, @@ -404,7 +359,7 @@ def random_split(dataset: torch.utils.data.Dataset, lengths: Sequence[Union[int, ) val_dataloader = torch.utils.data.DataLoader( - val, + val_dataset_group if val_dataset_group is not None else [], shuffle=False, batch_size=1, collate_fn=collator,