Skip to content

Commit

Permalink
reformat
Browse files Browse the repository at this point in the history
  • Loading branch information
yujincheng08 authored Nov 22, 2024
1 parent 2a43710 commit dec5433
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions app/sana_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,9 @@ def __init__(
null_caption_token = self.tokenizer(
"", max_length=self.max_sequence_length, padding="max_length", truncation=True, return_tensors="pt"
).to(self.device)
self.null_caption_embs = self.text_encoder(null_caption_token.input_ids, null_caption_token.attention_mask)[0]
self.null_caption_embs = self.text_encoder(null_caption_token.input_ids, null_caption_token.attention_mask)[
0
]

def build_vae(self, config):
vae = get_vae(config.vae_type, config.vae_pretrained, self.device).to(self.weight_dtype)
Expand Down Expand Up @@ -231,7 +233,9 @@ def forward(
)
for _ in range(num_images_per_prompt):
with torch.no_grad():
prompts.append(prepare_prompt_ar(prompt, self.base_ratios, device=self.device, show=False)[0].strip())
prompts.append(
prepare_prompt_ar(prompt, self.base_ratios, device=self.device, show=False)[0].strip()
)

# prepare text feature
if not self.config.text_encoder.chi_prompt:
Expand All @@ -246,7 +250,11 @@ def forward(
) # magic number 2: [bos], [_]

caption_token = self.tokenizer(
prompts_all, max_length=max_length_all, padding="max_length", truncation=True, return_tensors="pt"
prompts_all,
max_length=max_length_all,
padding="max_length",
truncation=True,
return_tensors="pt",
).to(device=self.device)
select_index = [0] + list(range(-self.config.text_encoder.model_max_length + 1, 0))
caption_embs = self.text_encoder(caption_token.input_ids, caption_token.attention_mask)[0][:, None][
Expand Down

0 comments on commit dec5433

Please sign in to comment.