From ab46142de967b21e031e2f7cdac82898e8a7e5f9 Mon Sep 17 00:00:00 2001
From: Daniel Gu
Date: Sun, 14 Jan 2024 17:09:48 -0800
Subject: [PATCH] [not sure if correct] Add tentative fixed implementation of
 discriminator R1 gradient penalty.

---
 examples/add/train_add_distill_sd_wds.py | 77 ++++++++++++++----------
 1 file changed, 46 insertions(+), 31 deletions(-)

diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py
index 6cb39e1dc19e..c5b2245be777 100644
--- a/examples/add/train_add_distill_sd_wds.py
+++ b/examples/add/train_add_distill_sd_wds.py
@@ -609,26 +609,27 @@ def train(self, mode: bool = True):
     def eval(self):
         return self.train(False)
 
-    def forward(
-        self,
-        x: torch.Tensor,
-        c_text: torch.Tensor,
-        c_img: Optional[torch.Tensor] = None,
-        transform_positive: bool = True,
-        return_dict: bool = True,
-    ):
-        # TODO: do we need the augmentations from the original StyleGAN-T code?
+    def get_features(self, image: torch.Tensor, transform_positive: bool = True) -> Dict[str, torch.Tensor]:
         if transform_positive:
             # Transform to [0, 1].
-            x = x.add(1).div(2)
+            image = image.add(1).div(2)
 
         # Forward pass through feature network.
-        features = self.feature_network(x)
+        features = self.feature_network(image)
+        return features
 
+    def forward_features(
+        self,
+        features: Dict[str, torch.Tensor],
+        c_text: torch.Tensor,
+        c_img: Optional[torch.Tensor] = None,
+        return_dict: bool = True,
+    ):
+        batch_size = next(iter(features.values())).size(0)
         # Apply discriminator heads.
         logits = []
         for k, head in self.heads.items():
-            logits.append(head(features[k], c_text, c_img).view(x.size(0), -1))
+            logits.append(head(features[k], c_text, c_img).view(batch_size, -1))
         logits = torch.cat(logits, dim=1)
 
         if not return_dict:
@@ -636,6 +637,18 @@ def forward(
 
         return DiscriminatorOutput(logits=logits, features=features)
 
+    def forward(
+        self,
+        image: torch.Tensor,
+        c_text: torch.Tensor,
+        c_img: Optional[torch.Tensor] = None,
+        transform_positive: bool = True,
+        return_dict: bool = True,
+    ):
+        features = self.get_features(image, transform_positive=transform_positive)
+        d_output = self.forward_features(features, c_text, c_img=c_img, return_dict=return_dict)
+        return d_output
+
 
 def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="student"):
     logger.info("Running validation... ")
@@ -1777,7 +1790,7 @@ def compute_image_embeddings(image_batch, image_encoder):
                 text_embedding = encoded_text.pop("text_embedding")
                 image_embedding = None
                 if args.use_image_conditioning:
-                    image_embedding = encoded_image.pop("image_embeds")
+                    image_embedding = encoded_image.pop("image_embeds").float()
                     # Only supply image conditioning when student timestep is not last training timestep T.
                     image_embedding = torch.where(
                         student_timesteps.unsqueeze(1) < noise_scheduler.config.num_train_timesteps - 1,
@@ -1842,32 +1855,34 @@ def compute_image_embeddings(image_batch, image_encoder):
                     )
                 student_gen_image = torch.cat(student_gen_image, dim=0)
 
-                # 2. Get discriminator real/fake outputs on the real and fake (generated) images respectively.
-                disc_output_real = discriminator(pixel_values.float(), text_embedding, image_embedding)
-                disc_output_fake = discriminator(student_gen_image.detach().float(), text_embedding, image_embedding)
-
-                # 3. Calculate the discriminator real adversarial loss terms.
-                d_logits_real = disc_output_real.logits
+                # 2. Calculate the discriminator real adversarial loss terms.
+                features_real = discriminator.get_features(pixel_values.float())
+                for feature in features_real.values():
+                    # Required so that the torch.autograd.grad call below works properly?
+                    feature.requires_grad_(True)
+                d_logits_real = discriminator.forward_features(
+                    features_real, text_embedding.float(), image_embedding, return_dict=False
+                )[0]
                 # Use hinge loss (see section 3.2, Equation 3 of paper)
                 d_adv_loss_real = torch.mean(F.relu(torch.ones_like(d_logits_real) - d_logits_real))
 
-                # 4. Calculate the discriminator R1 gradient penalty term with respect to the gradients from the real
-                # data.
+                # 3. Calculate the discriminator R1 gradient penalty term on the gradients with respect to the
+                # discriminator head input features from the real data.
                 d_r1_regularizer = 0
-                for k, head in discriminator.heads.items():
-                    head_grad_params = torch.autograd.grad(
-                        outputs=d_adv_loss_real, inputs=head.parameters(), create_graph=True
-                    )
-                    head_grad_norm = 0
-                    for grad in head_grad_params:
-                        head_grad_norm += grad.pow(2).sum()
-                    head_grad_norm = head_grad_norm.sqrt()
-                    d_r1_regularizer += head_grad_norm
+                grad_params = torch.autograd.grad(
+                    outputs=d_adv_loss_real,
+                    inputs=features_real.values(),
+                    create_graph=True,
+                )
+                for grad in grad_params:
+                    d_r1_regularizer += grad.pow(2).sum()
+                d_r1_regularizer = (d_r1_regularizer + 1e-12).sqrt()  # eps: finite sqrt grad at 0
 
                 d_loss_real = d_adv_loss_real + args.discriminator_r1_strength * d_r1_regularizer
                 accelerator.backward(d_loss_real, retain_graph=True)
 
-                # 5. Calculate the discriminator fake adversarial loss terms.
+                # 4. Calculate the discriminator fake adversarial loss terms.
+                disc_output_fake = discriminator(student_gen_image.detach().float(), text_embedding, image_embedding)
                 d_logits_fake = disc_output_fake.logits
                 # Use hinge loss (see section 3.2, Equation 3 of paper)
                 d_adv_loss_fake = torch.mean(F.relu(torch.ones_like(d_logits_fake) + d_logits_fake))