Skip to content

Commit

Permalink
feat: refactor latent cache format
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed Nov 15, 2024
1 parent 0047bb1 commit bdac55e
Show file tree
Hide file tree
Showing 17 changed files with 360 additions and 637 deletions.
2 changes: 1 addition & 1 deletion fine_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
vae.requires_grad_(False)
vae.eval()

train_dataset_group.new_cache_latents(vae, accelerator)
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)

vae.to("cpu")
clean_memory_on_device(accelerator.device)
Expand Down
2 changes: 1 addition & 1 deletion finetune/prepare_buckets_latents.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def process_batch(is_last):

# バッチへ追加
image_info = train_util.ImageInfo(image_key, 1, "", False, image_path)
image_info.latents_npz = npz_file_name
image_info.latents_cache_path = npz_file_name
image_info.bucket_reso = reso
image_info.resized_size = resized_size
image_info.image = image
Expand Down
2 changes: 1 addition & 1 deletion flux_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def train(args):
ae.requires_grad_(False)
ae.eval()

train_dataset_group.new_cache_latents(ae, accelerator)
train_dataset_group.new_cache_latents(ae, accelerator, args.force_cache_precision)

ae.to("cpu") # if no sampling, vae can be deleted
clean_memory_on_device(accelerator.device)
Expand Down
Loading

0 comments on commit bdac55e

Please sign in to comment.