From 71734b4cb2efcaeb89fe35adaad65e06ee756bf7 Mon Sep 17 00:00:00 2001 From: laksjdjf Date: Tue, 19 Mar 2024 21:20:01 +0900 Subject: [PATCH] support tcd --- modules/lcm/lcm_trainer.py | 42 ++++++++++++++++++++++++++++++-------- 1 file changed, 34 insertions(+), 8 deletions(-) diff --git a/modules/lcm/lcm_trainer.py b/modules/lcm/lcm_trainer.py index 115cda8..fc724e2 100644 --- a/modules/lcm/lcm_trainer.py +++ b/modules/lcm/lcm_trainer.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn from modules.trainer import BaseTrainer -from modules.scheduler import BaseScheduler +from modules.scheduler import BaseScheduler, substitution_t # additional config # trainer: @@ -17,15 +17,34 @@ def step(self, sample, model_output, t, prev_t, use_ddim=False): pred_original_sample = self.pred_original_sample(sample, model_output, t) if use_ddim: # for training step noise = self.pred_noise(sample, model_output, t) - else: + elif self.gamma == 0.0: noise = torch.randn_like(sample) + else: # tcd + inner_t = ((1 - self.gamma) * prev_t).round().long() + output_inner = self.step(sample, model_output, t, inner_t, use_ddim=True) + noise_random = torch.randn_like(sample) + return self.add_noise_inner(output_inner, noise_random, inner_t, prev_t) + return self.add_noise(pred_original_sample, noise, prev_t) + + # x_current_t -> x_target_t + def add_noise_inner(self, sample, noise, current_t, target_t): + alphas_bar_current = substitution_t(self.alphas_bar, current_t, sample.shape[0]) + alphas_bar_target = substitution_t(self.alphas_bar, target_t, sample.shape[0]) + + alphas_bar = alphas_bar_target / alphas_bar_current + + return alphas_bar.sqrt() * sample + (1 - alphas_bar).sqrt() * noise + class LCMTrainer(BaseTrainer): def __init__(self, config, diffusion, text_model, vae, scheduler, network): super().__init__(config, diffusion, text_model, vae, scheduler, network) self.scheduler = LCMScheduler(self.scheduler.v_prediction) # overwrite - + self.tcd = config.additional_conf.lcm.get("tcd", False) + gamma = 0.3 if self.tcd else 0.0 + setattr(self.scheduler, "gamma", gamma) + def prepare_modules_for_training(self, device="cuda"): super().prepare_modules_for_training(device) @@ -57,8 +76,12 @@ def loss(self, batch): num_inference_steps = self.config.additional_conf.lcm.num_inference_steps interval = 1000 // num_inference_steps timesteps = torch.randint(interval, 1000, (self.batch_size,), device=latents.device) - prev_timesteps = timesteps - interval + if self.tcd: + inner_timesteps = [] + for t in prev_timesteps: + inner_timesteps.append(torch.randint(0, t+1, (1,), device=latents.device)) + inner_timesteps = torch.cat(inner_timesteps) noise = torch.randn_like(latents) noisy_latents = self.scheduler.add_noise(latents, noise, timesteps) @@ -66,7 +89,7 @@ def loss(self, batch): with torch.autocast("cuda", dtype=self.autocast_dtype): model_output = self.diffusion(noisy_latents, timesteps, encoder_hidden_states, pooled_output, size_condition) pred_original_sample = self.scheduler.pred_original_sample(noisy_latents, model_output, timesteps) - + pred = self.scheduler.step(noisy_latents, model_output, timesteps, inner_timesteps, use_ddim=True) if self.tcd else pred_original_sample with torch.no_grad(): # one step ddim negative_encoder_hidden_states = self.negative_encoder_hidden_states.repeat(self.batch_size, 1, 1) @@ -81,8 +104,11 @@ def loss(self, batch): # target target_model_output = self.diffusion(prev_noisy_latents, prev_timesteps, encoder_hidden_states, pooled_output, size_condition) - target_original_sample = self.scheduler.pred_original_sample(prev_noisy_latents, target_model_output, prev_timesteps) + if self.tcd: + target = self.scheduler.step(prev_noisy_latents, target_model_output, prev_timesteps, inner_timesteps, use_ddim=True) + else: + target = self.scheduler.pred_original_sample(prev_noisy_latents, target_model_output, prev_timesteps) - loss = nn.functional.mse_loss(pred_original_sample.float(), target_original_sample.float(), reduction="mean") + loss = nn.functional.mse_loss(pred.float(), target.float(), reduction="mean") - return loss \ No newline at end of file + return loss