
Commit

Add validation split of datasets
rockerBOO committed Nov 5, 2023
1 parent 69cc525 · commit e2b819f
Showing 3 changed files with 128 additions and 110 deletions.
145 changes: 91 additions & 54 deletions library/config_util.py
@@ -85,6 +85,8 @@ class BaseDatasetParams:
max_token_length: int = None
resolution: Optional[Tuple[int, int]] = None
debug_dataset: bool = False
validation_seed: Optional[int] = None
validation_split: float = 0.0

@dataclass
class DreamBoothDatasetParams(BaseDatasetParams):
@@ -200,6 +202,8 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]
"enable_bucket": bool,
"max_bucket_reso": int,
"min_bucket_reso": int,
"validation_seed": int,
"validation_split": float,
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
}

@@ -427,64 +431,89 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
dataset_klass = FineTuningDataset

subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
dataset = dataset_klass(subsets=subsets, is_train=True, **asdict(dataset_blueprint.params))
datasets.append(dataset)

# print info
info = ""
for i, dataset in enumerate(datasets):
is_dreambooth = isinstance(dataset, DreamBoothDataset)
is_controlnet = isinstance(dataset, ControlNetDataset)
info += dedent(f"""\
[Dataset {i}]
batch_size: {dataset.batch_size}
resolution: {(dataset.width, dataset.height)}
enable_bucket: {dataset.enable_bucket}
""")

if dataset.enable_bucket:
info += indent(dedent(f"""\
min_bucket_reso: {dataset.min_bucket_reso}
max_bucket_reso: {dataset.max_bucket_reso}
bucket_reso_steps: {dataset.bucket_reso_steps}
bucket_no_upscale: {dataset.bucket_no_upscale}
\n"""), " ")
val_datasets:List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
for dataset_blueprint in dataset_group_blueprint.datasets:
if dataset_blueprint.params.validation_split <= 0.0:
continue
if dataset_blueprint.is_controlnet:
subset_klass = ControlNetSubset
dataset_klass = ControlNetDataset
elif dataset_blueprint.is_dreambooth:
subset_klass = DreamBoothSubset
dataset_klass = DreamBoothDataset
else:
info += "\n"

for j, subset in enumerate(dataset.subsets):
info += indent(dedent(f"""\
[Subset {j} of Dataset {i}]
image_dir: "{subset.image_dir}"
image_count: {subset.img_count}
num_repeats: {subset.num_repeats}
shuffle_caption: {subset.shuffle_caption}
keep_tokens: {subset.keep_tokens}
caption_dropout_rate: {subset.caption_dropout_rate}
caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
caption_prefix: {subset.caption_prefix}
caption_suffix: {subset.caption_suffix}
color_aug: {subset.color_aug}
flip_aug: {subset.flip_aug}
face_crop_aug_range: {subset.face_crop_aug_range}
random_crop: {subset.random_crop}
token_warmup_min: {subset.token_warmup_min},
token_warmup_step: {subset.token_warmup_step},
"""), " ")

if is_dreambooth:
subset_klass = FineTuningSubset
dataset_klass = FineTuningDataset

subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params))
val_datasets.append(dataset)

# print info
def print_info(_datasets):
info = ""
for i, dataset in enumerate(_datasets):
is_dreambooth = isinstance(dataset, DreamBoothDataset)
is_controlnet = isinstance(dataset, ControlNetDataset)
info += dedent(f"""\
[Dataset {i}]
batch_size: {dataset.batch_size}
resolution: {(dataset.width, dataset.height)}
enable_bucket: {dataset.enable_bucket}
""")

if dataset.enable_bucket:
info += indent(dedent(f"""\
is_reg: {subset.is_reg}
class_tokens: {subset.class_tokens}
caption_extension: {subset.caption_extension}
\n"""), " ")
elif not is_controlnet:
min_bucket_reso: {dataset.min_bucket_reso}
max_bucket_reso: {dataset.max_bucket_reso}
bucket_reso_steps: {dataset.bucket_reso_steps}
bucket_no_upscale: {dataset.bucket_no_upscale}
\n"""), " ")
else:
info += "\n"

for j, subset in enumerate(dataset.subsets):
info += indent(dedent(f"""\
metadata_file: {subset.metadata_file}
\n"""), " ")

print(info)
[Subset {j} of Dataset {i}]
image_dir: "{subset.image_dir}"
image_count: {subset.img_count}
num_repeats: {subset.num_repeats}
shuffle_caption: {subset.shuffle_caption}
keep_tokens: {subset.keep_tokens}
caption_dropout_rate: {subset.caption_dropout_rate}
caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
caption_prefix: {subset.caption_prefix}
caption_suffix: {subset.caption_suffix}
color_aug: {subset.color_aug}
flip_aug: {subset.flip_aug}
face_crop_aug_range: {subset.face_crop_aug_range}
random_crop: {subset.random_crop}
token_warmup_min: {subset.token_warmup_min},
token_warmup_step: {subset.token_warmup_step},
"""), " ")

if is_dreambooth:
info += indent(dedent(f"""\
is_reg: {subset.is_reg}
class_tokens: {subset.class_tokens}
caption_extension: {subset.caption_extension}
\n"""), " ")
elif not is_controlnet:
info += indent(dedent(f"""\
metadata_file: {subset.metadata_file}
\n"""), " ")

print(info)

print_info(datasets)

if len(val_datasets) > 0:
print("Validation dataset")
print_info(val_datasets)

# make buckets first because it determines the length of dataset
# and set the same seed for all datasets
@@ -494,7 +523,15 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
dataset.make_buckets()
dataset.set_seed(seed)

return DatasetGroup(datasets)
for i, dataset in enumerate(val_datasets):
print(f"[Validation Dataset {i}]")
dataset.make_buckets()
dataset.set_seed(seed)

return (
DatasetGroup(datasets),
DatasetGroup(val_datasets) if val_datasets else None
)


def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None):
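For orientation, a minimal sketch of how a caller might use the new dataset-level options and the changed return value. This is illustrative only: the directory path, resolution, batch size, and the 0.1 / 42 values are made-up placeholders, and it assumes blueprint_generator, args, and tokenizer are already set up as in train_network.py below.

from library import config_util

# Hypothetical config dict: "validation_split" and "validation_seed" are dataset-level
# options (registered above alongside "enable_bucket"), so they sit next to
# resolution / batch_size rather than inside a subset.
user_config = {
    "datasets": [
        {
            "resolution": 512,
            "batch_size": 4,
            "validation_split": 0.1,  # hold out 10% of each subset's images
            "validation_seed": 42,    # deterministic shuffle before splitting
            "subsets": [
                {"image_dir": "/path/to/train_images", "num_repeats": 10},
            ],
        }
    ],
}

blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)

# generate_dataset_group_by_blueprint now returns a pair; the second element is None
# when no dataset requested a validation split.
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
if val_dataset_group is not None:
    print(f"validation items: {len(val_dataset_group)}")
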
26 changes: 26 additions & 0 deletions library/train_util.py
@@ -122,6 +122,22 @@
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"


def split_train_val(paths, is_train, validation_split, validation_seed):
if validation_seed is not None:
print(f"Using validation seed: {validation_seed}")
prevstate = random.getstate()
random.seed(validation_seed)
random.shuffle(paths)
random.setstate(prevstate)
else:
random.shuffle(paths)

if is_train:
return paths[0:math.ceil(len(paths) * (1 - validation_split))]
else:
return paths[len(paths) - round(len(paths) * validation_split):]


class ImageInfo:
def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
self.image_key: str = image_key
@@ -1306,6 +1322,7 @@ class DreamBoothDataset(BaseDataset):
def __init__(
self,
subsets: Sequence[DreamBoothSubset],
is_train: bool,
batch_size: int,
tokenizer,
max_token_length,
@@ -1316,12 +1333,18 @@ def __init__(
bucket_reso_steps: int,
bucket_no_upscale: bool,
prior_loss_weight: float,
validation_split: float,
validation_seed: Optional[int],
debug_dataset,
) -> None:
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)

assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"

self.is_train = is_train
self.validation_split = validation_split
self.validation_seed = validation_seed

self.batch_size = batch_size
self.size = min(self.width, self.height) # 短いほう
self.prior_loss_weight = prior_loss_weight
@@ -1374,6 +1397,9 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
return [], []

img_paths = glob_images(subset.image_dir, "*")

if self.validation_split > 0.0:
img_paths = split_train_val(img_paths, self.is_train, self.validation_split, self.validation_seed)
print(f"found directory {subset.image_dir} contains {len(img_paths)} image files")

# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
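The helper above shuffles the path list in place (deterministically when validation_seed is given, restoring the RNG state afterwards) and then slices it: the training call keeps the leading ceil((1 - split) * n) paths and the validation call keeps the trailing round(split * n) paths. A small, runnable sketch (the function body is reproduced verbatim from the hunk above; the img_NN.png names and the 0.2 / 42 values are placeholders):

import math
import random

def split_train_val(paths, is_train, validation_split, validation_seed):
    # Verbatim copy of the helper added above, so this sketch runs on its own.
    if validation_seed is not None:
        print(f"Using validation seed: {validation_seed}")
        prevstate = random.getstate()
        random.seed(validation_seed)
        random.shuffle(paths)
        random.setstate(prevstate)
    else:
        random.shuffle(paths)

    if is_train:
        return paths[0:math.ceil(len(paths) * (1 - validation_split))]
    else:
        return paths[len(paths) - round(len(paths) * validation_split):]

paths = [f"img_{i:02d}.png" for i in range(10)]  # placeholder file names

train_paths = split_train_val(list(paths), True, 0.2, 42)   # 8 paths
val_paths = split_train_val(list(paths), False, 0.2, 42)    # 2 paths

# The same seed produces the same shuffle in both calls, so the two slices are
# disjoint and together cover every path.
print(sorted(train_paths + val_paths) == paths)  # True
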
67 changes: 11 additions & 56 deletions train_network.py
@@ -185,10 +185,11 @@ def train(self, args):
}

blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
# use arbitrary dataset class
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
val_dataset_group = None # placeholder until validation dataset supported for arbitrary

current_epoch = Value("i", 0)
current_step = Value("i", 0)
@@ -208,6 +209,10 @@ def train(self, args):
assert (
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
if val_dataset_group is not None:
assert (
val_dataset_group.is_latent_cacheable()
), "when caching validation latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"

self.assert_extra_args(args, train_dataset_group)

@@ -260,6 +265,9 @@ def train(self, args):
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
if val_dataset_group is not None:
print("Cache validation latents...")
val_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
@@ -341,61 +349,8 @@ def train(self, args):
# DataLoaderのプロセス数:0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで

def get_indices_without_reg(dataset: torch.utils.data.Dataset):
return [id for id, (key, item) in enumerate(dataset.image_data.items()) if item.is_reg is False]

from typing import Sequence, Union
from torch._utils import _accumulate
import warnings
from torch.utils.data.dataset import Subset

def random_split(dataset: torch.utils.data.Dataset, lengths: Sequence[Union[int, float]]):
indices = get_indices_without_reg(dataset)
random.shuffle(indices)

subset_lengths = []

for i, frac in enumerate(lengths):
if frac < 0 or frac > 1:
raise ValueError(f"Fraction at index {i} is not between 0 and 1")
n_items_in_split = int(math.floor(len(indices) * frac))
subset_lengths.append(n_items_in_split)

remainder = len(indices) - sum(subset_lengths)

for i in range(remainder):
idx_to_add_at = i % len(subset_lengths)
subset_lengths[idx_to_add_at] += 1

lengths = subset_lengths
for i, length in enumerate(lengths):
if length == 0:
warnings.warn(f"Length of split at index {i} is 0. "
f"This might result in an empty dataset.")

if sum(lengths) != len(indices):
raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

return [Subset(dataset, indices[offset - length: offset]) for offset, length in zip(_accumulate(lengths), lengths)]


if args.validation_ratio > 0.0:
train_ratio = 1 - args.validation_ratio
validation_ratio = args.validation_ratio
train, val = random_split(
train_dataset_group,
[train_ratio, validation_ratio]
)
print(f"split dataset by ratio: train {train_ratio}, validation {validation_ratio}")
print(f"train images: {len(train)}, validation images: {len(val)}")
else:
train = train_dataset_group
val = []



train_dataloader = torch.utils.data.DataLoader(
train,
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
Expand All @@ -404,7 +359,7 @@ def random_split(dataset: torch.utils.data.Dataset, lengths: Sequence[Union[int,
)

val_dataloader = torch.utils.data.DataLoader(
val,
val_dataset_group if val_dataset_group is not None else [],
shuffle=False,
batch_size=1,
collate_fn=collator,
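A side note on the fallback path (illustrative, not part of the commit): when no dataset defines a validation_split, val_dataset_group is None, so the validation DataLoader above wraps an empty list; such a loader has length 0 and yields no batches, which makes any downstream validation loop a no-op.

import torch

# Assumption for illustration: mirrors the `val_dataset_group if ... else []` fallback above.
val_dataloader = torch.utils.data.DataLoader([], shuffle=False, batch_size=1)

print(len(val_dataloader))    # 0
for batch in val_dataloader:  # never entered for the empty fallback
    raise AssertionError("unreachable")
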
