diff --git a/app/sana_pipeline.py b/app/sana_pipeline.py
index 5a20f14..cfbec8e 100644
--- a/app/sana_pipeline.py
+++ b/app/sana_pipeline.py
@@ -117,10 +117,13 @@ def __init__(
         self.model = self.build_sana_model(config).to(self.device)
 
         # 3. pre-compute null embedding
-        null_caption_token = self.tokenizer(
-            "", max_length=self.max_sequence_length, padding="max_length", truncation=True, return_tensors="pt"
-        ).to(self.device)
-        self.null_caption_embs = self.text_encoder(null_caption_token.input_ids, null_caption_token.attention_mask)[0]
+        with torch.no_grad():
+            null_caption_token = self.tokenizer(
+                "", max_length=self.max_sequence_length, padding="max_length", truncation=True, return_tensors="pt"
+            ).to(self.device)
+            self.null_caption_embs = self.text_encoder(null_caption_token.input_ids, null_caption_token.attention_mask)[
+                0
+            ]
 
     def build_vae(self, config):
         vae = get_vae(config.vae_type, config.vae_pretrained, self.device).to(self.weight_dtype)
@@ -155,6 +158,8 @@ def build_sana_model(self, config):
             "use_fp32_attention": config.model.get("fp32_attention", False) and config.model.mixed_precision != "bf16",
         }
         model = build_model(config.model.model, **model_kwargs)
+        model = model.to(self.weight_dtype)
+        self.logger.info(f"use_fp32_attention: {model.fp32_attention}")
         self.logger.info(
             f"{model.__class__.__name__}:{config.model.model},"
@@ -227,83 +232,90 @@ def forward(
                 torch.tensor([[1.0]], device=self.device).repeat(num_images_per_prompt, 1),
             )
             for _ in range(num_images_per_prompt):
-                prompts.append(prepare_prompt_ar(prompt, self.base_ratios, device=self.device, show=False)[0].strip())
-
-            # prepare text feature
-            if not self.config.text_encoder.chi_prompt:
-                max_length_all = self.config.text_encoder.model_max_length
-                prompts_all = prompts
-            else:
-                chi_prompt = "\n".join(self.config.text_encoder.chi_prompt)
-                prompts_all = [chi_prompt + prompt for prompt in prompts]
-                num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
-                max_length_all = (
-                    num_chi_prompt_tokens + self.config.text_encoder.model_max_length - 2
-                )  # magic number 2: [bos], [_]
-
-            caption_token = self.tokenizer(
-                prompts_all, max_length=max_length_all, padding="max_length", truncation=True, return_tensors="pt"
-            ).to(self.device)
-            select_index = [0] + list(range(-self.config.text_encoder.model_max_length + 1, 0))
-            caption_embs = self.text_encoder(caption_token.input_ids, caption_token.attention_mask)[0][:, None][
-                :, :, select_index
-            ].to(self.weight_dtype)
-            emb_masks = caption_token.attention_mask[:, select_index]
-            null_y = self.null_caption_embs.repeat(len(prompts), 1, 1)[:, None].to(self.weight_dtype)
-
-            # start sampling
-            with torch.no_grad():
-                n = len(prompts)
-                if latents is None:
-                    z = torch.randn(
-                        n,
-                        self.config.vae.vae_latent_dim,
-                        self.latent_size_h,
-                        self.latent_size_w,
-                        generator=generator,
-                        device=self.device,
-                        dtype=self.weight_dtype,
-                    )
-                else:
-                    z = latents.to(self.weight_dtype).to(self.device)
-                model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks)
-                if self.vis_sampler == "flow_euler":
-                    flow_solver = FlowEuler(
-                        self.model,
-                        condition=caption_embs,
-                        uncondition=null_y,
-                        cfg_scale=guidance_scale,
-                        model_kwargs=model_kwargs,
-                    )
-                    sample = flow_solver.sample(
-                        z,
-                        steps=num_inference_steps,
-                    )
-                elif self.vis_sampler == "flow_dpm-solver":
-                    scheduler = DPMS(
-                        self.model,
-                        condition=caption_embs,
-                        uncondition=null_y,
-                        guidance_type=self.guidance_type,
-                        cfg_scale=guidance_scale,
-                        pag_scale=pag_guidance_scale,
-                        pag_applied_layers=self.config.model.pag_applied_layers,
-                        model_type="flow",
-                        model_kwargs=model_kwargs,
-                        schedule="FLOW",
-                    )
-                    scheduler.register_progress_bar(self.progress_fn)
-                    sample = scheduler.sample(
-                        z,
-                        steps=num_inference_steps,
-                        order=2,
-                        skip_type="time_uniform_flow",
-                        method="multistep",
-                        flow_shift=self.flow_shift,
+                with torch.no_grad():
+                    prompts.append(
+                        prepare_prompt_ar(prompt, self.base_ratios, device=self.device, show=False)[0].strip()
                     )
 
+                    # prepare text feature
+                    if not self.config.text_encoder.chi_prompt:
+                        max_length_all = self.config.text_encoder.model_max_length
+                        prompts_all = prompts
+                    else:
+                        chi_prompt = "\n".join(self.config.text_encoder.chi_prompt)
+                        prompts_all = [chi_prompt + prompt for prompt in prompts]
+                        num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
+                        max_length_all = (
+                            num_chi_prompt_tokens + self.config.text_encoder.model_max_length - 2
+                        )  # magic number 2: [bos], [_]
+
+                    caption_token = self.tokenizer(
+                        prompts_all,
+                        max_length=max_length_all,
+                        padding="max_length",
+                        truncation=True,
+                        return_tensors="pt",
+                    ).to(device=self.device)
+                    select_index = [0] + list(range(-self.config.text_encoder.model_max_length + 1, 0))
+                    caption_embs = self.text_encoder(caption_token.input_ids, caption_token.attention_mask)[0][:, None][
+                        :, :, select_index
+                    ].to(self.weight_dtype)
+                    emb_masks = caption_token.attention_mask[:, select_index]
+                    null_y = self.null_caption_embs.repeat(len(prompts), 1, 1)[:, None].to(self.weight_dtype)
+
+                    n = len(prompts)
+                    if latents is None:
+                        z = torch.randn(
+                            n,
+                            self.config.vae.vae_latent_dim,
+                            self.latent_size_h,
+                            self.latent_size_w,
+                            generator=generator,
+                            device=self.device,
+                            dtype=self.weight_dtype,
+                        )
+                    else:
+                        z = latents.to(self.weight_dtype).to(self.device)
+                    model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks)
+                    if self.vis_sampler == "flow_euler":
+                        flow_solver = FlowEuler(
+                            self.model,
+                            condition=caption_embs,
+                            uncondition=null_y,
+                            cfg_scale=guidance_scale,
+                            model_kwargs=model_kwargs,
+                        )
+                        sample = flow_solver.sample(
+                            z,
+                            steps=num_inference_steps,
+                        )
+                    elif self.vis_sampler == "flow_dpm-solver":
+                        scheduler = DPMS(
+                            self.model,
+                            condition=caption_embs,
+                            uncondition=null_y,
+                            guidance_type=self.guidance_type,
+                            cfg_scale=guidance_scale,
+                            pag_scale=pag_guidance_scale,
+                            pag_applied_layers=self.config.model.pag_applied_layers,
+                            model_type="flow",
+                            model_kwargs=model_kwargs,
+                            schedule="FLOW",
+                        )
+                        scheduler.register_progress_bar(self.progress_fn)
+                        sample = scheduler.sample(
+                            z,
+                            steps=num_inference_steps,
+                            order=2,
+                            skip_type="time_uniform_flow",
+                            method="multistep",
+                            flow_shift=self.flow_shift,
+                        )
+
             sample = sample.to(self.weight_dtype)
-            sample = vae_decode(self.config.vae.vae_type, self.vae, sample)
+            with torch.no_grad():
+                sample = vae_decode(self.config.vae.vae_type, self.vae, sample)
+
             sample = resize_and_crop_tensor(sample, self.ori_width, self.ori_height)
             samples.append(sample)
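
Note (not part of the patch): the diff above wraps the null-embedding pre-compute in __init__, the per-prompt text encoding and sampling in forward, and the final vae_decode call in torch.no_grad(), so inference never records an autograd graph. A minimal, self-contained sketch of that pattern is shown below; the tiny nn.Linear modules are hypothetical stand-ins for the SANA text encoder and VAE, not the real classes.

import torch
import torch.nn as nn

# Hypothetical stand-ins for the pipeline's text encoder and VAE decoder.
text_encoder = nn.Linear(16, 32).eval()
vae_decoder = nn.Linear(8, 3 * 32 * 32).eval()

token_embs = torch.randn(1, 16)  # pretend tokenizer output
latents = torch.randn(1, 8)      # pretend sampled latents

# Without no_grad(), every forward call builds an autograd graph and keeps
# intermediate activations alive; for pure inference that is wasted memory.
with torch.no_grad():
    caption_embs = text_encoder(token_embs)  # analogous to the caption/null-embedding step
    images = vae_decoder(latents)            # analogous to vae_decode(...)

assert not caption_embs.requires_grad and not images.requires_grad

The patch applies the same idea at the granularity of the per-prompt loop in forward() and the one-time null-embedding computation in __init__().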