Commit 98b43e9: Merge branch 'val2'
gesen2egee committed Mar 9, 2024
2 parents e7e5061 + 2601314
Showing 3 changed files with 217 additions and 63 deletions.
124 changes: 70 additions & 54 deletions library/config_util.py
@@ -105,7 +105,8 @@ class BaseDatasetParams:
resolution: Optional[Tuple[int, int]] = None
network_multiplier: float = 1.0
debug_dataset: bool = False

validation_seed: Optional[int] = None
validation_split: float = 0.0

@dataclass
class DreamBoothDatasetParams(BaseDatasetParams):
@@ -116,8 +117,7 @@ class DreamBoothDatasetParams(BaseDatasetParams):
bucket_reso_steps: int = 64
bucket_no_upscale: bool = False
prior_loss_weight: float = 1.0



@dataclass
class FineTuningDatasetParams(BaseDatasetParams):
batch_size: int = 1
@@ -231,8 +231,11 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]
"enable_bucket": bool,
"max_bucket_reso": int,
"min_bucket_reso": int,
"validation_seed": int,
"validation_split": float,
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
"network_multiplier": float,

}

# options handled by argparse but not handled by user config
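
For reference, a minimal sketch (an illustration, not part of this diff) of how the two new dataset-level options might be supplied in the user config dict that this schema validates; the directory path and other values here are hypothetical:

# Hypothetical user config (e.g. as loaded from a dataset-config TOML file);
# "validation_split" and "validation_seed" are the keys added by this commit.
user_config = {
    "datasets": [
        {
            "resolution": 512,
            "validation_split": 0.1,  # hold out 10% of each subset's images for validation
            "validation_seed": 42,    # fixed seed so the train/val shuffle is reproducible
            "subsets": [
                {"image_dir": "/path/to/train_images", "num_repeats": 1},
            ],
        }
    ],
}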
@@ -472,39 +475,49 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
dataset_klass = FineTuningDataset

subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
dataset = dataset_klass(subsets=subsets, is_train=True, **asdict(dataset_blueprint.params))
datasets.append(dataset)

# print info
info = ""
for i, dataset in enumerate(datasets):
is_dreambooth = isinstance(dataset, DreamBoothDataset)
is_controlnet = isinstance(dataset, ControlNetDataset)
info += dedent(
f"""\
[Dataset {i}]
batch_size: {dataset.batch_size}
resolution: {(dataset.width, dataset.height)}
enable_bucket: {dataset.enable_bucket}
network_multiplier: {dataset.network_multiplier}
"""
)
val_datasets:List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []

for dataset_blueprint in dataset_group_blueprint.datasets:
if dataset_blueprint.params.validation_split <= 0.0:
continue
if dataset_blueprint.is_controlnet:
subset_klass = ControlNetSubset
dataset_klass = ControlNetDataset
elif dataset_blueprint.is_dreambooth:
subset_klass = DreamBoothSubset
dataset_klass = DreamBoothDataset
else:
subset_klass = FineTuningSubset
dataset_klass = FineTuningDataset

subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params))
val_datasets.append(dataset)

def print_info(_datasets):
info = ""
for i, dataset in enumerate(_datasets):
is_dreambooth = isinstance(dataset, DreamBoothDataset)
is_controlnet = isinstance(dataset, ControlNetDataset)
info += dedent(f"""\
[Dataset {i}]
batch_size: {dataset.batch_size}
resolution: {(dataset.width, dataset.height)}
enable_bucket: {dataset.enable_bucket}
""")

if dataset.enable_bucket:
info += indent(
dedent(
f"""\
min_bucket_reso: {dataset.min_bucket_reso}
max_bucket_reso: {dataset.max_bucket_reso}
bucket_reso_steps: {dataset.bucket_reso_steps}
bucket_no_upscale: {dataset.bucket_no_upscale}
\n"""
),
" ",
)
info += indent(dedent(f"""\
min_bucket_reso: {dataset.min_bucket_reso}
max_bucket_reso: {dataset.max_bucket_reso}
bucket_reso_steps: {dataset.bucket_reso_steps}
bucket_no_upscale: {dataset.bucket_no_upscale}
\n"""), " ")
else:
info += "\n"

for j, subset in enumerate(dataset.subsets):
info += indent(
dedent(
@@ -534,40 +547,43 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
" ",
)

if is_dreambooth:
info += indent(
dedent(
f"""\
is_reg: {subset.is_reg}
class_tokens: {subset.class_tokens}
caption_extension: {subset.caption_extension}
\n"""
),
" ",
)
elif not is_controlnet:
info += indent(
dedent(
f"""\
metadata_file: {subset.metadata_file}
\n"""
),
" ",
)
if is_dreambooth:
info += indent(dedent(f"""\
is_reg: {subset.is_reg}
class_tokens: {subset.class_tokens}
caption_extension: {subset.caption_extension}
\n"""), " ")
elif not is_controlnet:
info += indent(dedent(f"""\
metadata_file: {subset.metadata_file}
\n"""), " ")

logger.info(f"{info}")

print_info(datasets)

if len(val_datasets) > 0:
print("Validation dataset")
print_info(val_datasets)

# make buckets first because it determines the length of dataset
# and set the same seed for all datasets
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
for i, dataset in enumerate(datasets):
logger.info(f"[Dataset {i}]")
print(f"[Dataset {i}]")
dataset.make_buckets()
dataset.set_seed(seed)

for i, dataset in enumerate(val_datasets):
print(f"[Validation Dataset {i}]")
dataset.make_buckets()
dataset.set_seed(seed)

return DatasetGroup(datasets)


return (
DatasetGroup(datasets),
DatasetGroup(val_datasets) if val_datasets else None
)

def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None):
def extract_dreambooth_params(name: str) -> Tuple[int, str]:
tokens = name.split("_")
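generate_dataset_group_by_blueprint now returns a pair (the training group plus an optional validation group) instead of a single DatasetGroup, so call sites have to unpack two values. A minimal sketch of the adjusted call, assuming the usual blueprint flow in this repository (variable names are illustrative):

# val_dataset_group is None unless at least one dataset sets validation_split > 0.0.
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(
    blueprint.dataset_group
)
if val_dataset_group is not None:
    print(f"validation datasets: {len(val_dataset_group.datasets)}")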
22 changes: 22 additions & 0 deletions library/train_util.py
@@ -137,6 +137,20 @@
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
STABLE_CASCADE_LATENTS_CACHE_SUFFIX = "_sc_latents.npz"

def split_train_val(paths, is_train, validation_split, validation_seed):
if validation_seed is not None:
print(f"Using validation seed: {validation_seed}")
prevstate = random.getstate()
random.seed(validation_seed)
random.shuffle(paths)
random.setstate(prevstate)
else:
random.shuffle(paths)

if is_train:
return paths[0:math.ceil(len(paths) * (1 - validation_split))]
else:
return paths[len(paths) - round(len(paths) * validation_split):]
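
# --- Illustration only, not part of the diff: a hedged sketch of how this helper
# --- splits a file list, assuming it is importable from library.train_util.
from library.train_util import split_train_val

paths = [f"img_{i:02d}.png" for i in range(10)]  # made-up file names

# The same seed drives both shuffles, so the train and val slices are taken
# from one consistent ordering of the files.
train_paths = split_train_val(list(paths), True, 0.2, validation_seed=42)
val_paths = split_train_val(list(paths), False, 0.2, validation_seed=42)

# With 10 files and validation_split=0.2: ceil(10 * 0.8) = 8 training files,
# and the last round(10 * 0.2) = 2 files become the validation set.
assert len(train_paths) == 8 and len(val_paths) == 2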

class ImageInfo:
def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
@@ -1416,6 +1430,7 @@ class DreamBoothDataset(BaseDataset):
def __init__(
self,
subsets: Sequence[DreamBoothSubset],
is_train: bool,
batch_size: int,
tokenizer,
max_token_length,
@@ -1427,12 +1442,17 @@ def __init__(
bucket_reso_steps: int,
bucket_no_upscale: bool,
prior_loss_weight: float,
validation_split: float,
validation_seed: Optional[int],
debug_dataset: bool,
) -> None:
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)

assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"

self.is_train = is_train
self.validation_split = validation_split
self.validation_seed = validation_seed
self.batch_size = batch_size
self.size = min(self.width, self.height) # 短いほう
self.prior_loss_weight = prior_loss_weight
@@ -1485,6 +1505,8 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
return [], []

img_paths = glob_images(subset.image_dir, "*")
if self.validation_split > 0.0:
img_paths = split_train_val(img_paths, self.is_train, self.validation_split, self.validation_seed)
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")

# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う