From f5aecf628e602d4a95c1989a65cafaede6b2eba5 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Tue, 19 Mar 2024 16:03:38 +0800 Subject: [PATCH] Update train_util.py --- library/train_util.py | 32 ++++++-------------------------- 1 file changed, 6 insertions(+), 26 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index d6dfe0772..dbeecf733 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -453,7 +453,6 @@ def __init__( caption_suffix: Optional[str], token_warmup_min: int, token_warmup_step: Union[float, int], - sample_weight: bool, ) -> None: self.image_dir = image_dir self.num_repeats = num_repeats @@ -477,7 +476,6 @@ def __init__( self.token_warmup_min = token_warmup_min # step=0におけるタグの数 self.token_warmup_step = token_warmup_step # N(N<1ならN*max_train_steps)ステップ目でタグの数が最大になる - self.sample_weight = sample_weight self.img_count = 0 @@ -508,7 +506,6 @@ def __init__( caption_suffix, token_warmup_min, token_warmup_step, - sample_weight, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -534,7 +531,6 @@ def __init__( caption_suffix, token_warmup_min, token_warmup_step, - sample_weight, ) self.is_reg = is_reg @@ -574,7 +570,6 @@ def __init__( caption_suffix, token_warmup_min, token_warmup_step, - sample_weight, ) -> None: assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" @@ -600,7 +595,6 @@ def __init__( caption_suffix, token_warmup_min, token_warmup_step, - sample_weight, ) self.metadata_file = metadata_file @@ -637,7 +631,6 @@ def __init__( caption_suffix, token_warmup_min, token_warmup_step, - sample_weight, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -663,7 +656,6 @@ def __init__( caption_suffix, token_warmup_min, token_warmup_step, - sample_weight, ) self.conditioning_data_dir = conditioning_data_dir @@ -695,7 +687,6 @@ def __init__( self.width, self.height = (None, None) if resolution is None else resolution self.network_multiplier = network_multiplier self.debug_dataset = debug_dataset - self.sample_weight = sample_weight self.subsets: List[Union[DreamBoothSubset, FineTuningSubset]] = [] self.token_padding_disabled = False @@ -1271,13 +1262,12 @@ def __getitem__(self, index): image_info = self.image_data[image_key] subset = self.image_to_subset[image_key] sample_weight = 1.0 - if subset.sample_weight is not None: - sample_weight_path = os.path.splitext(info.absolute_path)[0] + ".weight" - try: - with open(sample_weight_path, 'r', encoding='utf-8') as file: - sample_weight = float(file.readline().strip()) - except (OSError, ValueError): - pass + sample_weight_path = os.path.splitext(info.absolute_path)[0] + ".weight" + try: + with open(sample_weight_path, 'r', encoding='utf-8') as file: + sample_weight = float(file.readline().strip()) + except (OSError, ValueError): + pass loss_weights.append(sample_weight * (self.prior_loss_weight if image_info.is_reg else 1.0)) # in case of fine tuning, is_reg is always False flipped = subset.flip_aug and random.random() < 0.5 # not flipped or flipped with 50% chance @@ -1533,7 +1523,6 @@ def __init__( validation_split: float, validation_seed: Optional[int], debug_dataset: bool, - sample_weight: bool, ) -> None: super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) @@ -1545,7 +1534,6 @@ def __init__( self.batch_size = batch_size self.size = min(self.width, self.height) # 短いほう self.prior_loss_weight = prior_loss_weight - self.sample_weight = sample_weight self.latents_cache = None self.enable_bucket = enable_bucket if self.enable_bucket: @@ -1963,7 +1951,6 @@ def __init__( subset.caption_suffix, subset.token_warmup_min, subset.token_warmup_step, - subset.sample_weight, ) db_subsets.append(db_subset) @@ -1984,7 +1971,6 @@ def __init__( validation_split, validation_seed, debug_dataset, - sample_weight, ) # config_util等から参照される値をいれておく(若干微妙なのでなんとかしたい) @@ -3514,12 +3500,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: "--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み" ) - parser.add_argument( - "--sample_weight", - action="store_true", - help="Enables the use of sample weights for images, where each weight is read from a `.weight` file corresponding to the image. / 画像に対してサンプルウェイトを使用することを可能にします。各ウェイトは、画像に対応する .weight ファイルから読み取られます。", - ) - def verify_training_args(args: argparse.Namespace): r""" Verify training arguments. Also reflect highvram option to global variable