Skip to content

Commit

Permalink
Add validation loss
Browse files Browse the repository at this point in the history
  • Loading branch information
rockerBOO committed Oct 30, 2023
1 parent 2a23713 commit ba45a62
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 1 deletion.
4 changes: 4 additions & 0 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4681,6 +4681,10 @@ def __call__(self, examples):
else:
dataset = self.dataset

# If we split a dataset we will get a Subset
if type(dataset) is torch.utils.data.Subset:
dataset = dataset.dataset

# set epoch and step
dataset.set_current_epoch(self.current_epoch.value)
dataset.set_current_step(self.current_step.value)
Expand Down
124 changes: 123 additions & 1 deletion train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,15 +341,37 @@ def train(self, args):
# DataLoaderのプロセス数:0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで

if args.validation_ratio > 0.0:
train_ratio = 1 - args.validation_ratio
validation_ratio = args.validation_ratio
train, val = torch.utils.data.random_split(
train_dataset_group,
[train_ratio, validation_ratio]
)
print(f"split dataset by ratio: train {train_ratio}, validation {validation_ratio}")
print(f"train images: {len(train)}, validation images: {len(val)}")
else:
train = train_dataset_group
val = []

train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
train,
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)

val_dataloader = torch.utils.data.DataLoader(
val,
shuffle=False,
batch_size=1,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)

# 学習ステップ数を計算する
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(
Expand Down Expand Up @@ -705,6 +727,10 @@ def train(self, args):

loss_list = []
loss_total = 0.0

val_loss_list = []
val_loss_total = 0.0

del train_dataset_group

# callback for step start
Expand Down Expand Up @@ -746,6 +772,8 @@ def remove_model(old_ckpt_name):

network.on_epoch_start(text_encoder, unet)

# TRAINING

for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(network):
Expand Down Expand Up @@ -874,6 +902,92 @@ def remove_model(old_ckpt_name):
if global_step >= args.max_train_steps:
break

# VALIDATION

if len(val_dataloader) > 0:
print("Validating バリデーション処理...")

with torch.no_grad():
for val_step, batch in enumerate(val_dataloader):
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
else:
# latentに変換
latents = vae.encode(batch["images"].to(device=accelerator.device, dtype=vae_dtype)).latent_dist.sample()

# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
latents = latents * self.vae_scale_factor
b_size = latents.shape[0]

# Get the text embedding for conditioning
if args.weighted_captions:
text_encoder_conds = get_weighted_text_embeddings(
tokenizer,
text_encoder,
batch["captions"],
accelerator.device,
args.max_token_length // 75 if args.max_token_length else 1,
clip_skip=args.clip_skip,
)
else:
text_encoder_conds = self.get_text_cond(
args, accelerator, batch, tokenizers, text_encoders, weight_dtype
)

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)

# Predict the noise residual
with accelerator.autocast():
noise_pred = self.call_unet(
args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
)

if args.v_parameterization:
# v-parameterization training
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
target = noise

loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])

loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight

loss = loss * loss_weights

if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss:
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)

loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし

current_loss = loss.detach().item()
if epoch == 0:
val_loss_list.append(current_loss)
else:
val_loss_total -= val_loss_list[val_step]
val_loss_list[val_step] = current_loss

val_loss_total += current_loss

if len(val_dataloader) > 0:
avg_loss = val_loss_total / len(val_loss_list)

if args.logging_dir is not None:
logs = {"loss/val": avg_loss}
accelerator.log(logs, step=epoch + 1)


if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(loss_list)}
accelerator.log(logs, step=epoch + 1)
Expand Down Expand Up @@ -996,6 +1110,14 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
)

parser.add_argument(
"--validation_ratio",
type=float,
default=0.0,
help="Ratio for validation images out of the training dataset"
)

return parser


Expand Down

0 comments on commit ba45a62

Please sign in to comment.