diff --git a/library/config_util.py b/library/config_util.py
index e8e0fda7c..ab90fb63b 100644
--- a/library/config_util.py
+++ b/library/config_util.py
@@ -51,6 +51,7 @@ class BaseSubsetParams:
   image_dir: Optional[str] = None
   num_repeats: int = 1
   shuffle_caption: bool = False
+  caption_separator: str = ','
   keep_tokens: int = 0
   color_aug: bool = False
   flip_aug: bool = False
diff --git a/library/train_util.py b/library/train_util.py
index cc9ac4555..9fb616ed6 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -349,6 +349,7 @@ def __init__(
         image_dir: Optional[str],
         num_repeats: int,
         shuffle_caption: bool,
+        caption_separator: str,
         keep_tokens: int,
         color_aug: bool,
         flip_aug: bool,
@@ -365,6 +366,7 @@ def __init__(
         self.image_dir = image_dir
         self.num_repeats = num_repeats
         self.shuffle_caption = shuffle_caption
+        self.caption_separator = caption_separator
         self.keep_tokens = keep_tokens
         self.color_aug = color_aug
         self.flip_aug = flip_aug
@@ -391,6 +393,7 @@ def __init__(
         caption_extension: str,
         num_repeats,
         shuffle_caption,
+        caption_separator: str,
         keep_tokens,
         color_aug,
         flip_aug,
@@ -410,6 +413,7 @@ def __init__(
             image_dir,
             num_repeats,
             shuffle_caption,
+            caption_separator,
             keep_tokens,
             color_aug,
             flip_aug,
@@ -443,6 +447,7 @@ def __init__(
         metadata_file: str,
         num_repeats,
         shuffle_caption,
+        caption_separator,
         keep_tokens,
         color_aug,
         flip_aug,
@@ -462,6 +467,7 @@ def __init__(
             image_dir,
             num_repeats,
             shuffle_caption,
+            caption_separator,
             keep_tokens,
             color_aug,
             flip_aug,
@@ -492,6 +498,7 @@ def __init__(
         caption_extension: str,
         num_repeats,
         shuffle_caption,
+        caption_separator,
         keep_tokens,
         color_aug,
         flip_aug,
@@ -511,6 +518,7 @@ def __init__(
             image_dir,
             num_repeats,
             shuffle_caption,
+            caption_separator,
             keep_tokens,
             color_aug,
             flip_aug,
@@ -646,7 +654,7 @@ def process_caption(self, subset: BaseSubset, caption):
             caption = ""
         else:
             if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
-                tokens = [t.strip() for t in caption.strip().split(",")]
+                tokens = [t.strip() for t in caption.strip().split(subset.caption_separator)]
                 if subset.token_warmup_step < 1:  # 初回に上書きする
                     subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
                 if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
@@ -3105,7 +3113,10 @@ def add_dataset_arguments(
     # dataset common
     parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
     parser.add_argument(
-        "--shuffle_caption", action="store_true", help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする"
+        "--shuffle_caption", action="store_true", help="shuffle separated caption / 区切られたcaptionの各要素をshuffleする"
+    )
+    parser.add_argument(
+        "--caption_separator", type=str, default=",", help="separator for caption / captionの区切り文字"
     )
     parser.add_argument(
         "--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子"