From 6893b27483b4a8e27587bee53572831f6796bf70 Mon Sep 17 00:00:00 2001
From: Andy W <37781802+aandyw@users.noreply.github.com>
Date: Wed, 27 Dec 2023 04:55:19 -0500
Subject: [PATCH] Fix "push_to_hub only create repo in consistency model lora
 SDXL training script" (#6102)

* fix

* style fix

---------

Co-authored-by: Sayak Paul
---
 .../train_lcm_distill_lora_sd_wds.py       | 12 ++++++++++--
 .../train_lcm_distill_lora_sdxl_wds.py     | 12 ++++++++++--
 .../train_lcm_distill_sd_wds.py            | 12 ++++++++++--
 .../train_lcm_distill_sdxl_wds.py          | 12 ++++++++++--
 4 files changed, 40 insertions(+), 8 deletions(-)

diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
index 05689b71fa047..c85e2c462b047 100644
--- a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
@@ -38,7 +38,7 @@
 from accelerate.logging import get_logger
 from accelerate.utils import ProjectConfiguration, set_seed
 from braceexpand import braceexpand
-from huggingface_hub import create_repo
+from huggingface_hub import create_repo, upload_folder
 from packaging import version
 from peft import LoraConfig, get_peft_model, get_peft_model_state_dict
 from torch.utils.data import default_collate
@@ -847,7 +847,7 @@ def main(args):
             os.makedirs(args.output_dir, exist_ok=True)
 
         if args.push_to_hub:
-            create_repo(
+            repo_id = create_repo(
                 repo_id=args.hub_model_id or Path(args.output_dir).name,
                 exist_ok=True,
                 token=args.hub_token,
@@ -1366,6 +1366,14 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
         lora_state_dict = get_peft_model_state_dict(unet, adapter_name="default")
         StableDiffusionPipeline.save_lora_weights(os.path.join(args.output_dir, "unet_lora"), lora_state_dict)
 
+        if args.push_to_hub:
+            upload_folder(
+                repo_id=repo_id,
+                folder_path=args.output_dir,
+                commit_message="End of training",
+                ignore_patterns=["step_*", "epoch_*"],
+            )
+
     accelerator.end_training()
 
 
diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
index 014a770fa0ba3..75671c18c5e09 100644
--- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
@@ -39,7 +39,7 @@
 from accelerate.logging import get_logger
 from accelerate.utils import ProjectConfiguration, set_seed
 from braceexpand import braceexpand
-from huggingface_hub import create_repo
+from huggingface_hub import create_repo, upload_folder
 from packaging import version
 from peft import LoraConfig, get_peft_model, get_peft_model_state_dict
 from torch.utils.data import default_collate
@@ -842,7 +842,7 @@ def main(args):
             os.makedirs(args.output_dir, exist_ok=True)
 
         if args.push_to_hub:
-            create_repo(
+            repo_id = create_repo(
                 repo_id=args.hub_model_id or Path(args.output_dir).name,
                 exist_ok=True,
                 token=args.hub_token,
@@ -1424,6 +1424,14 @@ def compute_embeddings(
         lora_state_dict = get_peft_model_state_dict(unet, adapter_name="default")
         StableDiffusionXLPipeline.save_lora_weights(os.path.join(args.output_dir, "unet_lora"), lora_state_dict)
 
+        if args.push_to_hub:
+            upload_folder(
+                repo_id=repo_id,
+                folder_path=args.output_dir,
+                commit_message="End of training",
+                ignore_patterns=["step_*", "epoch_*"],
+            )
+
     accelerator.end_training()
 
 
diff --git a/examples/consistency_distillation/train_lcm_distill_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_sd_wds.py
index 9e8df77aacbd3..b2085b7044ba8 100644
--- a/examples/consistency_distillation/train_lcm_distill_sd_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_sd_wds.py
@@ -38,7 +38,7 @@
 from accelerate.logging import get_logger
 from accelerate.utils import ProjectConfiguration, set_seed
 from braceexpand import braceexpand
-from huggingface_hub import create_repo
+from huggingface_hub import create_repo, upload_folder
 from packaging import version
 from torch.utils.data import default_collate
 from torchvision import transforms
@@ -835,7 +835,7 @@ def main(args):
             os.makedirs(args.output_dir, exist_ok=True)
 
         if args.push_to_hub:
-            create_repo(
+            repo_id = create_repo(
                 repo_id=args.hub_model_id or Path(args.output_dir).name,
                 exist_ok=True,
                 token=args.hub_token,
@@ -1354,6 +1354,14 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
         target_unet = accelerator.unwrap_model(target_unet)
         target_unet.save_pretrained(os.path.join(args.output_dir, "unet_target"))
 
+        if args.push_to_hub:
+            upload_folder(
+                repo_id=repo_id,
+                folder_path=args.output_dir,
+                commit_message="End of training",
+                ignore_patterns=["step_*", "epoch_*"],
+            )
+
     accelerator.end_training()
 
 
diff --git a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
index da49649c918c8..ee86def673fae 100644
--- a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
@@ -39,7 +39,7 @@
 from accelerate.logging import get_logger
 from accelerate.utils import ProjectConfiguration, set_seed
 from braceexpand import braceexpand
-from huggingface_hub import create_repo
+from huggingface_hub import create_repo, upload_folder
 from packaging import version
 from torch.utils.data import default_collate
 from torchvision import transforms
@@ -875,7 +875,7 @@ def main(args):
             os.makedirs(args.output_dir, exist_ok=True)
 
         if args.push_to_hub:
-            create_repo(
+            repo_id = create_repo(
                 repo_id=args.hub_model_id or Path(args.output_dir).name,
                 exist_ok=True,
                 token=args.hub_token,
@@ -1457,6 +1457,14 @@ def compute_embeddings(
         target_unet = accelerator.unwrap_model(target_unet)
         target_unet.save_pretrained(os.path.join(args.output_dir, "unet_target"))
 
+        if args.push_to_hub:
+            upload_folder(
+                repo_id=repo_id,
+                folder_path=args.output_dir,
+                commit_message="End of training",
+                ignore_patterns=["step_*", "epoch_*"],
+            )
+
     accelerator.end_training()
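
For context, the pattern the patch applies is the standard huggingface_hub two-step flow: capture the repo id that create_repo returns when the run starts, then hand it to upload_folder once training has written its final artifacts. Before the fix, the scripts called create_repo but dropped its return value, so nothing was ever uploaded. Below is a minimal sketch of the fixed flow; it assumes you are already authenticated with the Hub (e.g. via `huggingface-cli login`), and output_dir is a hypothetical stand-in for args.output_dir.

    from pathlib import Path

    from huggingface_hub import create_repo, upload_folder

    output_dir = "lcm-distill-output"  # hypothetical stand-in for args.output_dir

    # create_repo returns a RepoUrl; its .repo_id ("user/name") is what
    # upload_folder needs later. Discarding this value was the bug.
    repo_id = create_repo(
        repo_id=Path(output_dir).name,
        exist_ok=True,  # reuse the repo if it already exists
    ).repo_id

    # ... training runs and writes weights/checkpoints into output_dir ...

    # Push the final artifacts, skipping intermediate checkpoints.
    upload_folder(
        repo_id=repo_id,
        folder_path=output_dir,
        commit_message="End of training",
        ignore_patterns=["step_*", "epoch_*"],
    )

The ignore_patterns filter matches the step_*/epoch_* checkpoint directories these training scripts write, so only the final weights reach the Hub.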