From 2dafec58e2c6e75b9ced6bbc12a9bd10e3151d3d Mon Sep 17 00:00:00 2001
From: isamu-isozaki
Date: Sat, 26 Aug 2023 18:55:56 -0400
Subject: [PATCH 1/3] Started working on second stage training

---
 training/train_muse.py | 37 ++++++++++++++++++++++++++++++++-----
 1 file changed, 32 insertions(+), 5 deletions(-)

diff --git a/training/train_muse.py b/training/train_muse.py
index bf8e6ede..c71167f6 100644
--- a/training/train_muse.py
+++ b/training/train_muse.py
@@ -307,7 +307,10 @@ def main():
 
     vq_class = get_vq_model_class(config.model.vq_model.type)
     vq_model = vq_class.from_pretrained(config.model.vq_model.pretrained)
-
+    if config.training.is_second_stage_training:
+        low_res_vq_class = get_vq_model_class(config.model.low_res_vq_model.type)
+        low_res_vq_model = low_res_vq_class.from_pretrained(config.model.low_res_vq_model.pretrained)
+        low_res_vq_model.requires_grad_(False)
     # Freeze the text model and VQGAN
     text_encoder.requires_grad_(False)
     vq_model.requires_grad_(False)
@@ -487,6 +490,8 @@ def save_model_hook(models, weights, output_dir):
     if not is_pre_encode:
         text_encoder.to(device=accelerator.device, dtype=weight_dtype)
         vq_model.to(device=accelerator.device)
+        if config.training.is_second_stage_training:
+            low_res_vq_model.to(device=accelerator.device)
 
     if config.training.get("use_ema", False):
         ema.to(accelerator.device)
@@ -562,7 +567,19 @@ def save_model_hook(models, weights, output_dir):
 
             global_step = int(os.path.basename(path).split("-")[1])
             first_epoch = global_step // num_update_steps_per_epoch
 
+    @torch.no_grad()
+    def prepare_low_res_image_tokens(
+        pixel_values_or_image_ids: Union[torch.FloatTensor, torch.LongTensor],
+    ):
+        # TODO: Currently does not work with pre_encode. Fix
+        if is_pre_encode:
+            low_res_image_tokens = pixel_values_or_image_ids
+        else:
+            # Lower the resolution, then re-encode with the low-res VQ model
+            pixel_values_or_image_ids = F.interpolate(pixel_values_or_image_ids, (pixel_values_or_image_ids.shape[2] // 2, pixel_values_or_image_ids.shape[3] // 2))
+            low_res_image_tokens = low_res_vq_model.get_code(pixel_values_or_image_ids)
+        return low_res_image_tokens
     @torch.no_grad()
     def prepare_inputs_and_labels(
         pixel_values_or_image_ids: Union[torch.FloatTensor, torch.LongTensor],
@@ -647,7 +664,10 @@ def prepare_inputs_and_labels(
                 clip_embeds,
                 micro_conds,
             ) = prepare_inputs_and_labels(pixel_values, input_ids, config.training.min_masking_rate, batch=batch)
-
+            additional_args = {}
+            if config.training.is_second_stage_training:
+                low_res_input_ids = prepare_low_res_image_tokens(pixel_values)
+                additional_args['low_res_input_ids'] = low_res_input_ids
             # log the inputs for the first step of the first epoch
             if global_step == 0 and epoch == 0:
                 logger.info("Input ids: {}".format(input_ids))
@@ -660,6 +680,7 @@ def prepare_inputs_and_labels(
                     input_ids=input_ids,
                     encoder_hidden_states=encoder_hidden_states,
                     cond_dropout_prob=config.training.cond_dropout_prob,
+                    **additional_args
                 )
                 loss = soft_target_cross_entropy(logits, labels, soft_targets)
             else:
@@ -674,6 +695,7 @@ def prepare_inputs_and_labels(
                     empty_embeds=empty_embeds,
                     empty_cond_embeds=empty_clip_embeds,
                     micro_conds=micro_conds,
+                    **additional_args
                 )
 
             # Gather the losses across all processes for logging (if we use distributed training).
@@ -798,7 +820,7 @@ def prepare_inputs_and_labels(
                 if config.training.get("use_ema", False):
                     ema.store(model.parameters())
                     ema.copy_to(model.parameters())
-
+                # Do 2nd stage generation of images
                 generate_images(
                     model,
                     vq_model,
@@ -828,7 +850,7 @@ def prepare_inputs_and_labels(
 
     # Evaluate and save checkpoint at the end of training
    if accelerator.is_main_process:
-        validate_model(model, eval_dataloader, accelerator, global_step, prepare_inputs_and_labels)
+        validate_model(model, eval_dataloader, accelerator, global_step, prepare_inputs_and_labels, prepare_low_res_image_tokens, is_second_stage_training=config.training.is_second_stage_training)
         save_checkpoint(model, config, accelerator, global_step)
 
     # Save the final trained checkpoint
@@ -842,7 +864,7 @@ def prepare_inputs_and_labels(
 
 
 @torch.no_grad()
-def validate_model(model, eval_dataloader, accelerator, global_step, prepare_inputs_and_labels, empty_embeds=None):
+def validate_model(model, eval_dataloader, accelerator, global_step, prepare_inputs_and_labels, prepare_low_res_image_tokens, empty_embeds=None, is_second_stage_training=False):
     logger.info("Evaluating...")
     model.eval()
     eval_loss = 0
@@ -861,6 +883,10 @@ def validate_model(model, eval_dataloader, accelerator, global_step, prepare_inp
             clip_embeds,
             micro_conds,
         ) = prepare_inputs_and_labels(pixel_values, input_ids, batch=batch, is_train=False)
+        additional_args = {}
+        if is_second_stage_training:
+            low_res_input_ids = prepare_low_res_image_tokens(pixel_values)
+            additional_args["low_res_input_ids"] = low_res_input_ids
         _, loss = model(
             input_ids=input_ids,
             encoder_hidden_states=encoder_hidden_states,
@@ -869,6 +895,7 @@ def validate_model(model, eval_dataloader, accelerator, global_step, prepare_inp
             loss_weight=loss_weight,
             empty_embeds=empty_embeds,
             micro_conds=micro_conds,
+            **additional_args
         )
         eval_loss += loss.mean()
     eval_loss = eval_loss / (i + 1)
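
For readers following along, the core of PATCH 1 is the new prepare_low_res_image_tokens helper: halve the image resolution, then tokenize the result with the frozen low-res VQ model. Below is a minimal, self-contained sketch of that flow; DummyLowResVQ, the bilinear interpolation mode, and the 16x tokenizer stride are illustrative assumptions, not part of the patch, which relies on the repository's VQ-GAN and its get_code method.

import torch
import torch.nn.functional as F


class DummyLowResVQ:
    """Stand-in for the frozen low-res VQ model: maps pixels to integer code ids."""

    def get_code(self, pixel_values: torch.FloatTensor) -> torch.LongTensor:
        b, _, h, w = pixel_values.shape
        # A real VQ-GAN downsamples spatially (assumed 16x here) and returns code indices.
        return torch.zeros(b, (h // 16) * (w // 16), dtype=torch.long)


@torch.no_grad()
def prepare_low_res_image_tokens(pixel_values: torch.FloatTensor, low_res_vq_model) -> torch.LongTensor:
    # Halve each spatial dimension, then re-encode with the low-res tokenizer.
    target_size = (pixel_values.shape[2] // 2, pixel_values.shape[3] // 2)
    low_res_pixels = F.interpolate(pixel_values, size=target_size, mode="bilinear", align_corners=False)
    return low_res_vq_model.get_code(low_res_pixels)


tokens = prepare_low_res_image_tokens(torch.randn(2, 3, 512, 512), DummyLowResVQ())
print(tokens.shape)  # torch.Size([2, 256]) with the dummy 16x tokenizer
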
From b593cc562771b822d9af0888da152bbe15ff3afa Mon Sep 17 00:00:00 2001
From: isamu-isozaki
Date: Sat, 26 Aug 2023 21:51:23 -0400
Subject: [PATCH 2/3] Basic idea done

---
 muse/modeling_transformer.py | 19 +++++++++++++++++--
 training/train_muse.py       |  8 ++++++++
 2 files changed, 25 insertions(+), 2 deletions(-)

diff --git a/muse/modeling_transformer.py b/muse/modeling_transformer.py
index 4a09cdb6..c72d5218 100644
--- a/muse/modeling_transformer.py
+++ b/muse/modeling_transformer.py
@@ -1076,8 +1076,13 @@ def forward(self, hidden_states):
         logits = logits.permute(0, 2, 3, 1).view(batch_size, -1, self.vocab_size)
         return logits
 
 
+class TransformerAdapterMixin:
+    def __init__(self):
+        self.adapter = None
+    def add_adapter(self, adapter):
+        self.adapter = adapter
-class MaskGitTransformer(ModelMixin, ConfigMixin):
+class MaskGitTransformer(ModelMixin, ConfigMixin, TransformerAdapterMixin):
     _supports_gradient_checkpointing = True
 
     @register_to_config
@@ -1226,6 +1231,7 @@ def forward(
         labels=None,
         label_smoothing=0.0,
         cond_dropout_prob=0.0,
+        low_res_input_ids=None,
         **kwargs,
     ):
         if self.config.add_cross_attention and encoder_hidden_states is None:
@@ -1236,6 +1242,10 @@ def forward(
         if encoder_hidden_states is not None and self.config.project_encoder_hidden_states:
             encoder_hidden_states = self.encoder_proj(encoder_hidden_states)
             encoder_hidden_states = self.encoder_proj_layer_norm(encoder_hidden_states)
+        if low_res_input_ids is not None:
+            assert self.adapter is not None
+            low_res_hidden_states = self.adapter(low_res_input_ids)
+            encoder_hidden_states = torch.concat([encoder_hidden_states, low_res_hidden_states], dim=-1)
 
         # condition dropout for classifier free guidance
         if encoder_hidden_states is not None and self.training and cond_dropout_prob > 0.0:
@@ -1453,7 +1463,7 @@ def generate2(
         return sampled_ids
 
 
-class MaskGiTUViT(ModelMixin, ConfigMixin):
+class MaskGiTUViT(ModelMixin, ConfigMixin, TransformerAdapterMixin):
     _supports_gradient_checkpointing = True
 
     @register_to_config
@@ -1764,6 +1774,7 @@ def forward(
         empty_embeds=None,
         empty_cond_embeds=None,
         micro_conds=None,
+        low_res_input_ids=None
     ):
         if self.config.add_cross_attention and encoder_hidden_states is None:
             raise ValueError("If `add_cross_attention` is True, `encoder_hidden_states` should be provided.")
@@ -1803,6 +1814,10 @@ def forward(
         if encoder_hidden_states is not None and self.config.project_encoder_hidden_states:
             encoder_hidden_states = self.encoder_proj(encoder_hidden_states)
             encoder_hidden_states = self.encoder_proj_layer_norm(encoder_hidden_states)
+        if low_res_input_ids is not None:
+            assert self.adapter is not None
+            low_res_hidden_states = self.adapter(low_res_input_ids)
+            encoder_hidden_states = torch.concat([encoder_hidden_states, low_res_hidden_states], dim=-1)
 
         if self.config.add_micro_cond_embeds:
             micro_cond_embeds = sinusoidal_enocde(micro_conds.flatten(), self.config.micro_cond_encode_dim)
diff --git a/training/train_muse.py b/training/train_muse.py
index c71167f6..b3dffcb1 100644
--- a/training/train_muse.py
+++ b/training/train_muse.py
@@ -324,6 +324,14 @@ def main():
         model = model_cls.from_pretrained(config.model.pretrained_model_path)
     else:
         model = model_cls(**config.model.transformer)
+
+    if config.training.is_second_stage_training:
+        adapter_model_cls = MaskGitTransformer if config.adapter_model.get("architecture", "transformer") == "transformer" else MaskGiTUViT
+        if config.adapter_model.get("pretrained_model_path", None) is not None:
+            adapter_model = adapter_model_cls.from_pretrained(config.adapter_model.pretrained_model_path)
+        else:
+            adapter_model = adapter_model_cls(**config.adapter_model.transformer)
+        model.add_adapter(adapter_model)
     mask_id = model.config.mask_token_id
     output_size = model.output_size
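
PATCH 2 wires the low-res tokens into the model: TransformerAdapterMixin lets a second "adapter" transformer be attached after construction, and its output is concatenated onto encoder_hidden_states. The toy sketch below mirrors that flow with stand-in modules (ToyMaskGit, ToyAdapter, and the small nn.Embedding layers are illustrative, not repository code) to make the shape contract of the dim=-1 concatenation explicit: batch and sequence length must already match, and only the feature dimension may differ. The sketch also skips the mixin's __init__ and simply assumes add_adapter() is called, since with nn.Module attribute handling a default set in a mixin initializer may never take effect.

import torch
import torch.nn as nn


class TransformerAdapterMixin:
    """Attach a conditioning adapter to an already-constructed model."""

    def add_adapter(self, adapter: nn.Module) -> None:
        self.adapter = adapter


class ToyMaskGit(nn.Module, TransformerAdapterMixin):
    def __init__(self, vocab_size=8192, hidden=64):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)

    def forward(self, input_ids, encoder_hidden_states, low_res_input_ids=None):
        hidden_states = self.embed(input_ids)
        if low_res_input_ids is not None:
            # Assumes add_adapter() was called; mirrors the patch's conditioning path.
            low_res_hidden_states = self.adapter(low_res_input_ids)
            # Concatenation along dim=-1 requires matching batch and sequence length;
            # only the feature dimension may differ between the two tensors.
            encoder_hidden_states = torch.concat([encoder_hidden_states, low_res_hidden_states], dim=-1)
        return hidden_states, encoder_hidden_states


class ToyAdapter(nn.Module):
    def __init__(self, vocab_size=8192, hidden=32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)

    def forward(self, low_res_input_ids):
        return self.embed(low_res_input_ids)


model = ToyMaskGit()
model.add_adapter(ToyAdapter())
image_ids = torch.randint(0, 8192, (2, 16))
low_res_ids = torch.randint(0, 8192, (2, 16))
text_cond = torch.randn(2, 16, 64)
_, cond = model(image_ids, text_cond, low_res_input_ids=low_res_ids)
print(cond.shape)  # torch.Size([2, 16, 96])
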
From 02dccb3e42d7e556b7c495eb8386c611175c0834 Mon Sep 17 00:00:00 2001
From: isamu-isozaki
Date: Tue, 19 Sep 2023 22:18:45 -0400
Subject: [PATCH 3/3] Resolved projection issue

---
 muse/modeling_transformer.py | 1 +
 training/train_muse.py       | 7 ++++---
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/muse/modeling_transformer.py b/muse/modeling_transformer.py
index c72d5218..9fa29a44 100644
--- a/muse/modeling_transformer.py
+++ b/muse/modeling_transformer.py
@@ -1817,6 +1817,7 @@ def forward(
         if low_res_input_ids is not None:
             assert self.adapter is not None
             low_res_hidden_states = self.adapter(low_res_input_ids)
+            print(low_res_hidden_states.shape, encoder_hidden_states.shape)
             encoder_hidden_states = torch.concat([encoder_hidden_states, low_res_hidden_states], dim=-1)
 
         if self.config.add_micro_cond_embeds:
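
The hunk above keeps an inline print of the two shapes inside forward while the projection is being debugged. If a removable check is preferred, a forward hook on the adapter reports the same information; this is only a sketch, and the nn.Embedding stands in for the adapter transformer.

import torch
import torch.nn as nn


def log_adapter_shapes(module: nn.Module, inputs, output) -> None:
    # Fires after every adapter forward pass; remove the handle once shapes check out.
    print("adapter in:", tuple(inputs[0].shape), "out:", tuple(output.shape))


adapter = nn.Embedding(8192, 64)
handle = adapter.register_forward_hook(log_adapter_shapes)
_ = adapter(torch.randint(0, 8192, (2, 256)))
handle.remove()  # no stray prints left in the training loop afterwards
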
diff --git a/training/train_muse.py b/training/train_muse.py
index b3dffcb1..629b3ad9 100644
--- a/training/train_muse.py
+++ b/training/train_muse.py
@@ -326,11 +326,12 @@ def main():
         model = model_cls(**config.model.transformer)
 
     if config.training.is_second_stage_training:
-        adapter_model_cls = MaskGitTransformer if config.adapter_model.get("architecture", "transformer") == "transformer" else MaskGiTUViT
+        adapter_model_cls = MaskGitTransformer if config.model.adapter.get("architecture", "transformer") == "transformer" else MaskGiTUViT
+        assert config.model.adapter.num_vq_tokens == config.model.adapter.max_position_embeddings * 16
         if config.adapter_model.get("pretrained_model_path", None) is not None:
-            adapter_model = adapter_model_cls.from_pretrained(config.adapter_model.pretrained_model_path)
+            adapter_model = adapter_model_cls.from_pretrained(config.model.adapter.pretrained_model_path)
         else:
-            adapter_model = adapter_model_cls(**config.adapter_model.transformer)
+            adapter_model = adapter_model_cls(**config.model.adapter)
         model.add_adapter(adapter_model)
     mask_id = model.config.mask_token_id
     output_size = model.output_size
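
After this patch the second-stage branch reads its settings from config.model.adapter, while the condition between the two replaced assignments still reads config.adapter_model.get("pretrained_model_path", ...), so a config may need to define both keys. A hypothetical config fragment covering only the keys the diff touches might look like the following; OmegaConf is assumed because the config object supports both attribute access and .get() in these patches, and every concrete value shown is illustrative.

from omegaconf import OmegaConf

config = OmegaConf.create(
    {
        "training": {"is_second_stage_training": True},
        "model": {
            "adapter": {
                "architecture": "transformer",
                "pretrained_model_path": None,
                # Mirrors the assert above: num_vq_tokens == max_position_embeddings * 16.
                "max_position_embeddings": 256,
                "num_vq_tokens": 4096,
            }
        },
    }
)

assert config.model.adapter.num_vq_tokens == config.model.adapter.max_position_embeddings * 16
print(OmegaConf.to_yaml(config))
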