WIP: Second stage training #107

Open · wants to merge 3 commits into `main`
20 changes: 18 additions & 2 deletions muse/modeling_transformer.py
@@ -1076,8 +1076,13 @@ def forward(self, hidden_states):
        logits = logits.permute(0, 2, 3, 1).view(batch_size, -1, self.vocab_size)
        return logits


class TransformerAdapterMixin:
    def __init__(self):
        self.adapter = None

    def add_adapter(self, adapter):
        self.adapter = adapter

-class MaskGitTransformer(ModelMixin, ConfigMixin):
+class MaskGitTransformer(ModelMixin, ConfigMixin, TransformerAdapterMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
@@ -1226,6 +1231,7 @@ def forward(
        labels=None,
        label_smoothing=0.0,
        cond_dropout_prob=0.0,
        low_res_input_ids=None,
        **kwargs,
    ):
        if self.config.add_cross_attention and encoder_hidden_states is None:
@@ -1236,6 +1242,10 @@
        if encoder_hidden_states is not None and self.config.project_encoder_hidden_states:
            encoder_hidden_states = self.encoder_proj(encoder_hidden_states)
            encoder_hidden_states = self.encoder_proj_layer_norm(encoder_hidden_states)
        if low_res_input_ids is not None:
            assert self.adapter is not None
            low_res_hidden_states = self.adapter(low_res_input_ids)
            encoder_hidden_states = torch.concat([encoder_hidden_states, low_res_hidden_states], dim=-1)

        # condition dropout for classifier free guidance
        if encoder_hidden_states is not None and self.training and cond_dropout_prob > 0.0:
@@ -1453,7 +1463,7 @@ def generate2(
        return sampled_ids


-class MaskGiTUViT(ModelMixin, ConfigMixin):
+class MaskGiTUViT(ModelMixin, ConfigMixin, TransformerAdapterMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
@@ -1764,6 +1774,7 @@ def forward(
        empty_embeds=None,
        empty_cond_embeds=None,
        micro_conds=None,
        low_res_input_ids=None,
    ):
        if self.config.add_cross_attention and encoder_hidden_states is None:
            raise ValueError("If `add_cross_attention` is True, `encoder_hidden_states` should be provided.")
@@ -1803,6 +1814,11 @@ def forward(
        if encoder_hidden_states is not None and self.config.project_encoder_hidden_states:
            encoder_hidden_states = self.encoder_proj(encoder_hidden_states)
            encoder_hidden_states = self.encoder_proj_layer_norm(encoder_hidden_states)
        if low_res_input_ids is not None:
            assert self.adapter is not None
            low_res_hidden_states = self.adapter(low_res_input_ids)
            encoder_hidden_states = torch.concat([encoder_hidden_states, low_res_hidden_states], dim=-1)

        if self.config.add_micro_cond_embeds:
            micro_cond_embeds = sinusoidal_enocde(micro_conds.flatten(), self.config.micro_cond_encode_dim)
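To see the adapter hook in isolation, here is a minimal, self-contained sketch of the pattern the modeling changes introduce. `DummyAdapter`, `DummyModel`, and all shapes are illustrative assumptions, not code from this PR. One caveat the sketch makes explicit: concatenating on `dim=-1` (the feature axis), as the diff currently does, only works when the text and low-res sequences have the same length, so the sketch appends along the sequence axis instead.

```python
# Minimal sketch of the TransformerAdapterMixin pattern added in this PR.
# DummyAdapter / DummyModel and every size below are illustrative assumptions.
import torch
import torch.nn as nn


class TransformerAdapterMixin:
    def __init__(self):
        self.adapter = None

    def add_adapter(self, adapter):
        self.adapter = adapter


class DummyAdapter(nn.Module):
    """Stand-in for the second transformer that encodes low-res VQ tokens."""

    def __init__(self, vocab_size=8192, hidden_size=768):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)

    def forward(self, input_ids):
        return self.embed(input_ids)  # (batch, low_res_seq_len, hidden_size)


class DummyModel(nn.Module, TransformerAdapterMixin):
    def __init__(self):
        super().__init__()  # nn.Module.__init__ via the MRO
        TransformerAdapterMixin.__init__(self)  # sets self.adapter = None

    def forward(self, encoder_hidden_states, low_res_input_ids=None):
        if low_res_input_ids is not None:
            assert self.adapter is not None
            low_res_hidden_states = self.adapter(low_res_input_ids)
            # Appending along the sequence axis (dim=1) keeps the feature width
            # constant; the PR's dim=-1 instead requires matching sequence lengths.
            encoder_hidden_states = torch.cat(
                [encoder_hidden_states, low_res_hidden_states], dim=1
            )
        return encoder_hidden_states


model = DummyModel()
model.add_adapter(DummyAdapter())
text_states = torch.randn(2, 77, 768)           # e.g. projected text-encoder states
low_res_ids = torch.randint(0, 8192, (2, 256))  # codes from the low-res VQ model
print(model(text_states, low_res_ids).shape)    # torch.Size([2, 333, 768])
```

Also worth noting: `TransformerAdapterMixin.__init__` only runs if the host class chains to it, so `self.adapter` may never be set on real model instances unless their `__init__` cooperates; the sketch calls the mixin's `__init__` explicitly for that reason.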
46 changes: 41 additions & 5 deletions training/train_muse.py
@@ -307,7 +307,10 @@ def main():

    vq_class = get_vq_model_class(config.model.vq_model.type)
    vq_model = vq_class.from_pretrained(config.model.vq_model.pretrained)

    if config.training.is_second_stage_training:
        low_res_vq_class = get_vq_model_class(config.model.low_res_vq_model.type)
        low_res_vq_model = low_res_vq_class.from_pretrained(config.model.low_res_vq_model.pretrained)
        low_res_vq_model.requires_grad_(False)
    # Freeze the text model and VQGAN
    text_encoder.requires_grad_(False)
    vq_model.requires_grad_(False)
@@ -321,6 +324,15 @@ def main():
        model = model_cls.from_pretrained(config.model.pretrained_model_path)
    else:
        model = model_cls(**config.model.transformer)

    if config.training.is_second_stage_training:
        adapter_model_cls = MaskGitTransformer if config.model.adapter.get("architecture", "transformer") == "transformer" else MaskGiTUViT
        assert config.model.adapter.num_vq_tokens == config.model.adapter.max_position_embeddings * 16
        if config.model.adapter.get("pretrained_model_path", None) is not None:
            adapter_model = adapter_model_cls.from_pretrained(config.model.adapter.pretrained_model_path)
        else:
            adapter_model = adapter_model_cls(**config.model.adapter)
        model.add_adapter(adapter_model)
    mask_id = model.config.mask_token_id
    output_size = model.output_size

@@ -487,6 +499,8 @@ def save_model_hook(models, weights, output_dir):
    if not is_pre_encode:
        text_encoder.to(device=accelerator.device, dtype=weight_dtype)
        vq_model.to(device=accelerator.device)
        if config.training.is_second_stage_training:
            low_res_vq_model.to(device=accelerator.device)
    if config.training.get("use_ema", False):
        ema.to(accelerator.device)

@@ -562,7 +576,19 @@ def save_model_hook(models, weights, output_dir):

        global_step = int(os.path.basename(path).split("-")[1])
        first_epoch = global_step // num_update_steps_per_epoch

    @torch.no_grad()
    def prepare_low_res_image_tokens(
        pixel_values_or_image_ids: Union[torch.FloatTensor, torch.LongTensor],
    ):
        # TODO: Currently does not work with pre_encode. Fix
        if is_pre_encode:
            low_res_image_tokens = pixel_values_or_image_ids
        else:
            # Halve the spatial resolution, then encode with the low-res VQ model
            pixel_values_or_image_ids = F.interpolate(
                pixel_values_or_image_ids,
                (pixel_values_or_image_ids.shape[2] // 2, pixel_values_or_image_ids.shape[3] // 2),
            )
            low_res_image_tokens = low_res_vq_model.get_code(pixel_values_or_image_ids)

        return low_res_image_tokens

    @torch.no_grad()
    def prepare_inputs_and_labels(
        pixel_values_or_image_ids: Union[torch.FloatTensor, torch.LongTensor],
@@ -647,7 +673,10 @@ def prepare_inputs_and_labels(
                clip_embeds,
                micro_conds,
            ) = prepare_inputs_and_labels(pixel_values, input_ids, config.training.min_masking_rate, batch=batch)

            additional_args = {}
            if config.training.is_second_stage_training:
                low_res_input_ids = prepare_low_res_image_tokens(pixel_values)
                additional_args["low_res_input_ids"] = low_res_input_ids
            # log the inputs for the first step of the first epoch
            if global_step == 0 and epoch == 0:
                logger.info("Input ids: {}".format(input_ids))
@@ -660,6 +689,7 @@ def prepare_inputs_and_labels(
                    input_ids=input_ids,
                    encoder_hidden_states=encoder_hidden_states,
                    cond_dropout_prob=config.training.cond_dropout_prob,
                    **additional_args,
                )
                loss = soft_target_cross_entropy(logits, labels, soft_targets)
            else:
@@ -674,6 +704,7 @@ def prepare_inputs_and_labels(
                    empty_embeds=empty_embeds,
                    empty_cond_embeds=empty_clip_embeds,
                    micro_conds=micro_conds,
                    **additional_args,
                )

            # Gather the losses across all processes for logging (if we use distributed training).
@@ -798,7 +829,7 @@ def prepare_inputs_and_labels(
                if config.training.get("use_ema", False):
                    ema.store(model.parameters())
                    ema.copy_to(model.parameters())

                # Do 2nd stage generation of images
                generate_images(
                    model,
                    vq_model,
@@ -828,7 +859,7 @@ def prepare_inputs_and_labels(

    # Evaluate and save checkpoint at the end of training
    if accelerator.is_main_process:
-        validate_model(model, eval_dataloader, accelerator, global_step, prepare_inputs_and_labels)
+        validate_model(model, eval_dataloader, accelerator, global_step, prepare_inputs_and_labels, prepare_low_res_image_tokens, is_second_stage_training=config.training.is_second_stage_training)
        save_checkpoint(model, config, accelerator, global_step)

    # Save the final trained checkpoint
@@ -842,7 +873,7 @@


@torch.no_grad()
-def validate_model(model, eval_dataloader, accelerator, global_step, prepare_inputs_and_labels, empty_embeds=None):
+def validate_model(model, eval_dataloader, accelerator, global_step, prepare_inputs_and_labels, prepare_low_res_image_tokens, empty_embeds=None, is_second_stage_training=False):
    logger.info("Evaluating...")
    model.eval()
    eval_loss = 0
@@ -861,6 +892,10 @@ def validate_model(
            clip_embeds,
            micro_conds,
        ) = prepare_inputs_and_labels(pixel_values, input_ids, batch=batch, is_train=False)
        additional_args = {}
        if is_second_stage_training:
            low_res_input_ids = prepare_low_res_image_tokens(pixel_values)
            additional_args["low_res_input_ids"] = low_res_input_ids
        _, loss = model(
            input_ids=input_ids,
            encoder_hidden_states=encoder_hidden_states,
@@ -869,6 +904,7 @@ def validate_model(
            loss_weight=loss_weight,
            empty_embeds=empty_embeds,
            micro_conds=micro_conds,
            **additional_args,
        )
        eval_loss += loss.mean()
    eval_loss = eval_loss / (i + 1)
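End to end, the second-stage data path is: halve the resolution of each training batch, encode it with the frozen low-res VQ model, and forward the resulting codes to the transformer as `low_res_input_ids`. A hedged sketch of that path, assuming a hypothetical `DummyVQ` stand-in (only the `get_code(pixel_values)` method is required, matching how `low_res_vq_model` is used above) and illustrative sizes:

```python
# Sketch of the second-stage token preparation: 2x downsample, then VQ-encode.
# DummyVQ, the vocabulary size, and the 16x16 patch-per-token assumption are
# illustrative stand-ins, not code from this PR.
import torch
import torch.nn.functional as F


class DummyVQ:
    """Stand-in for the low-res VQ model; only get_code() is exercised here."""

    def get_code(self, pixel_values):
        b, _, h, w = pixel_values.shape
        # Pretend the VQ encoder maps each 16x16 pixel patch to one code.
        return torch.randint(0, 8192, (b, (h // 16) * (w // 16)))


low_res_vq_model = DummyVQ()


@torch.no_grad()
def prepare_low_res_image_tokens(pixel_values):
    # Halve the spatial resolution, then quantize to low-res image tokens.
    pixel_values = F.interpolate(
        pixel_values, (pixel_values.shape[2] // 2, pixel_values.shape[3] // 2)
    )
    return low_res_vq_model.get_code(pixel_values)


pixel_values = torch.randn(2, 3, 512, 512)  # a high-res training batch
low_res_input_ids = prepare_low_res_image_tokens(pixel_values)
print(low_res_input_ids.shape)  # torch.Size([2, 256])
# These are the ids that **additional_args forwards into model(...) above.
```

The nearest-neighbor default of `F.interpolate` is fine for a sketch, but an antialiased resize (e.g. `mode="bilinear", antialias=True`) may be closer to how the low-res VQ model was trained; that choice is left open in the PR.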