[advanced flux training] bug fix + reduce memory cost as in #9829 (#9838)

* memory improvement as done in #9829

* fix bug

* fix bug

* style

---------

Co-authored-by: Sayak Paul <[email protected]>
linoytsaban and sayakpaul authored Nov 19, 2024
1 parent 03bf77c commit acf479b
Showing 2 changed files with 17 additions and 3 deletions.
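
The bug being fixed: with a single shared --instance_prompt (no per-image custom prompts), tokens_one and tokens_two are tokenized once, with a batch dimension of 1, while each training batch carries len(prompts) images, so the embeddings returned by encode_prompt no longer match the latent batch. The fix repeats the token ids along the batch dimension first. A minimal sketch of the shape alignment, with illustrative names and sizes rather than the scripts' real values:

    import torch

    batch_size = 4   # stand-in for len(prompts), the images per batch
    seq_len = 77     # stand-in for the CLIP tokenizer's max length

    # A single shared instance prompt is tokenized once: batch dim is 1 ...
    tokens_one = torch.randint(0, 49408, (1, seq_len))

    # ... so repeat along dim 0 to match the number of images in the batch.
    elems_to_repeat = batch_size
    repeated = tokens_one.repeat(elems_to_repeat, 1)
    assert repeated.shape == (batch_size, seq_len)

    # With per-image custom prompts the tokens are already batch-sized,
    # so the repeat factor is 1 and .repeat(1, 1) is a no-op.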
14 changes: 12 additions & 2 deletions examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
@@ -2154,6 +2154,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):

                 # encode batch prompts when custom prompts are provided for each image -
                 if train_dataset.custom_instance_prompts:
+                    elems_to_repeat = 1
                     if freeze_text_encoder:
                         prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
                             prompts, text_encoders, tokenizers
@@ -2168,17 +2169,21 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                             max_sequence_length=args.max_sequence_length,
                             add_special_tokens=add_special_tokens_t5,
                         )
+                else:
+                    elems_to_repeat = len(prompts)

                 if not freeze_text_encoder:
                     prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
                         text_encoders=[text_encoder_one, text_encoder_two],
                         tokenizers=[None, None],
-                        text_input_ids_list=[tokens_one, tokens_two],
+                        text_input_ids_list=[
+                            tokens_one.repeat(elems_to_repeat, 1),
+                            tokens_two.repeat(elems_to_repeat, 1),
+                        ],
                         max_sequence_length=args.max_sequence_length,
                         device=accelerator.device,
                         prompt=prompts,
                     )

                 # Convert images to latent space
                 if args.cache_latents:
                     model_input = latents_cache[step].sample()
@@ -2371,6 +2376,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                         epoch=epoch,
                         torch_dtype=weight_dtype,
                     )
+                    images = None
+                    del pipeline
+
                     if freeze_text_encoder:
                         del text_encoder_one, text_encoder_two
                         free_memory()
@@ -2448,6 +2456,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 commit_message="End of training",
                 ignore_patterns=["step_*", "epoch_*"],
             )
+        images = None
+        del pipeline

     accelerator.end_training()
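
The memory reduction mirrors #9829: once validation images have been logged (and again after the final upload at end of training), the script drops its references to the images and the inference pipeline so the allocator can reclaim VRAM. A rough, runnable sketch of that release pattern, assuming free_memory() amounts to a gc-plus-empty_cache sequence; the pipeline and images below are hypothetical stand-ins:

    import gc

    import torch

    # Stand-ins for the script's FluxPipeline and log_validation() outputs.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipeline = torch.nn.Linear(4, 4).to(device)
    images = [torch.rand(3, 64, 64) for _ in range(4)]

    images = None    # drop the generated images
    del pipeline     # drop the pipeline; its weights become unreferenced
    gc.collect()     # collect the now-unreferenced objects
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # hand cached CUDA blocks back to the driver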
6 changes: 5 additions & 1 deletion examples/dreambooth/train_dreambooth_lora_flux.py
@@ -1648,11 +1648,15 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                         prompt=prompts,
                     )
                 else:
+                    elems_to_repeat = len(prompts)
                     if args.train_text_encoder:
                         prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
                             text_encoders=[text_encoder_one, text_encoder_two],
                             tokenizers=[None, None],
-                            text_input_ids_list=[tokens_one, tokens_two],
+                            text_input_ids_list=[
+                                tokens_one.repeat(elems_to_repeat, 1),
+                                tokens_two.repeat(elems_to_repeat, 1),
+                            ],
                             max_sequence_length=args.max_sequence_length,
                             device=accelerator.device,
                             prompt=args.instance_prompt,
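
Taken together, the repeat factor applied when token ids are passed to encode_prompt condenses to the paraphrase below; the helper name is invented for illustration and does not exist in the scripts:

    def repeat_factor(custom_instance_prompts: bool, prompts: list) -> int:
        # Per-image prompts are tokenized per batch and already batch-sized.
        if custom_instance_prompts:
            return 1
        # A single shared prompt was tokenized once; repeat it to batch size.
        return len(prompts)

    assert repeat_factor(True, ["a", "b", "c"]) == 1
    assert repeat_factor(False, ["a", "b", "c"]) == 3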
