
Commit

Update train_util.py
gesen2egee committed Mar 19, 2024
1 parent a671fb0 commit f5aecf6
Showing 1 changed file with 6 additions and 26 deletions.
32 changes: 6 additions & 26 deletions library/train_util.py
@@ -453,7 +453,6 @@ def __init__(
         caption_suffix: Optional[str],
         token_warmup_min: int,
         token_warmup_step: Union[float, int],
-        sample_weight: bool,
     ) -> None:
         self.image_dir = image_dir
         self.num_repeats = num_repeats
@@ -477,7 +476,6 @@ def __init__(

         self.token_warmup_min = token_warmup_min # step=0におけるタグの数
         self.token_warmup_step = token_warmup_step # N(N<1ならN*max_train_steps)ステップ目でタグの数が最大になる
-        self.sample_weight = sample_weight
         self.img_count = 0


@@ -508,7 +506,6 @@ def __init__(
         caption_suffix,
         token_warmup_min,
         token_warmup_step,
-        sample_weight,
     ) -> None:
         assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"

@@ -534,7 +531,6 @@ def __init__(
             caption_suffix,
             token_warmup_min,
             token_warmup_step,
-            sample_weight,
         )

         self.is_reg = is_reg
@@ -574,7 +570,6 @@ def __init__(
         caption_suffix,
         token_warmup_min,
         token_warmup_step,
-        sample_weight,
     ) -> None:
         assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"

@@ -600,7 +595,6 @@ def __init__(
             caption_suffix,
             token_warmup_min,
             token_warmup_step,
-            sample_weight,
         )

         self.metadata_file = metadata_file
@@ -637,7 +631,6 @@ def __init__(
         caption_suffix,
         token_warmup_min,
         token_warmup_step,
-        sample_weight,
     ) -> None:
         assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"

@@ -663,7 +656,6 @@ def __init__(
             caption_suffix,
             token_warmup_min,
             token_warmup_step,
-            sample_weight,
         )

         self.conditioning_data_dir = conditioning_data_dir
@@ -695,7 +687,6 @@ def __init__(
         self.width, self.height = (None, None) if resolution is None else resolution
         self.network_multiplier = network_multiplier
         self.debug_dataset = debug_dataset
-        self.sample_weight = sample_weight
         self.subsets: List[Union[DreamBoothSubset, FineTuningSubset]] = []

         self.token_padding_disabled = False
@@ -1271,13 +1262,12 @@ def __getitem__(self, index):
             image_info = self.image_data[image_key]
             subset = self.image_to_subset[image_key]
             sample_weight = 1.0
-            if subset.sample_weight is not None:
-                sample_weight_path = os.path.splitext(info.absolute_path)[0] + ".weight"
-                try:
-                    with open(sample_weight_path, 'r', encoding='utf-8') as file:
-                        sample_weight = float(file.readline().strip())
-                except (OSError, ValueError):
-                    pass
+            sample_weight_path = os.path.splitext(info.absolute_path)[0] + ".weight"
+            try:
+                with open(sample_weight_path, 'r', encoding='utf-8') as file:
+                    sample_weight = float(file.readline().strip())
+            except (OSError, ValueError):
+                pass
             loss_weights.append(sample_weight * (self.prior_loss_weight if image_info.is_reg else 1.0)) # in case of fine tuning, is_reg is always False

             flipped = subset.flip_aug and random.random() < 0.5 # not flipped or flipped with 50% chance
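Side note: the weight lookup in the hunk above amounts to reading an optional sidecar file next to each training image and falling back to a neutral weight of 1.0. A minimal standalone sketch of that behaviour follows; the helper name and the example path are illustrative and not part of this commit.

import os

def read_sample_weight(image_path: str, default: float = 1.0) -> float:
    # Look for a ".weight" file next to the image, e.g. img001.png -> img001.weight.
    weight_path = os.path.splitext(image_path)[0] + ".weight"
    try:
        with open(weight_path, "r", encoding="utf-8") as f:
            # The file is expected to hold a single float on its first line, e.g. "0.5".
            return float(f.readline().strip())
    except (OSError, ValueError):
        # Missing or malformed file: keep the neutral weight.
        return default

print(read_sample_weight("/data/train/img001.png"))  # prints 1.0 if no .weight file exists

In the diff itself, the value read this way is then multiplied by prior_loss_weight for regularization images when it is appended to loss_weights.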
@@ -1533,7 +1523,6 @@ def __init__(
         validation_split: float,
         validation_seed: Optional[int],
         debug_dataset: bool,
-        sample_weight: bool,
     ) -> None:
         super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)

@@ -1545,7 +1534,6 @@ def __init__(
         self.batch_size = batch_size
         self.size = min(self.width, self.height) # 短いほう
         self.prior_loss_weight = prior_loss_weight
-        self.sample_weight = sample_weight
         self.latents_cache = None
         self.enable_bucket = enable_bucket
         if self.enable_bucket:
@@ -1963,7 +1951,6 @@ def __init__(
                 subset.caption_suffix,
                 subset.token_warmup_min,
                 subset.token_warmup_step,
-                subset.sample_weight,
             )
             db_subsets.append(db_subset)

@@ -1984,7 +1971,6 @@ def __init__(
             validation_split,
             validation_seed,
             debug_dataset,
-            sample_weight,
         )

         # config_util等から参照される値をいれておく(若干微妙なのでなんとかしたい)
@@ -3514,12 +3500,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         "--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み"
     )

-    parser.add_argument(
-        "--sample_weight",
-        action="store_true",
-        help="Enables the use of sample weights for images, where each weight is read from a `.weight` file corresponding to the image. / 画像に対してサンプルウェイトを使用することを可能にします。各ウェイトは、画像に対応する .weight ファイルから読み取られます。",
-    )
-

 def verify_training_args(args: argparse.Namespace):
     r"""
     Verify training arguments. Also reflect highvram option to global variable
