Skip to content

Commit

Permalink
support sample generation in TI training
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed Feb 28, 2023
1 parent 57c565c commit 8270765
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2074,7 +2074,7 @@ def save_state_on_train_end(args: argparse.Namespace, accelerator):
SCHEDLER_SCHEDULE = 'scaled_linear'


def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet):
def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None):
"""
生成に使っている Diffusers の Pipeline がデフォルトなので、プロンプトの重みづけには対応していない
clip skipは対応した
Expand Down Expand Up @@ -2103,8 +2103,6 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v
if args.clip_skip is None:
text_encoder_or_wrapper = text_encoder
else:
print("create wrapper")

class Wrapper():
def __init__(self, tenc) -> None:
self.tenc = tenc
Expand All @@ -2116,7 +2114,7 @@ def __call__(self, input_ids, attention_mask):
encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
encoder_hidden_states = self.tenc.text_model.final_layer_norm(encoder_hidden_states)
pooled_output = enc_out['pooler_output']
return encoder_hidden_states, pooled_output # 1st output is only used
return encoder_hidden_states, pooled_output # 1st output is only used

text_encoder_or_wrapper = Wrapper(text_encoder)

Expand Down Expand Up @@ -2229,12 +2227,17 @@ def __call__(self, input_ids, attention_mask):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

if prompt_replacement is not None:
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
if negative_prompt is not None:
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])

image = pipeline(prompt, height, width, sample_steps, scale, negative_prompt).images[0]

ts_str = time.strftime('%Y%m%d%H%M%S', time.localtime())
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
seed_suffix = "" if seed is None else f"_{seed}"
img_filename = f"{'' if args.output_name is None else args.output_name}_{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png"

image.save(os.path.join(save_dir, img_filename))

Expand Down

0 comments on commit 8270765

Please sign in to comment.